blob: e0c6cf0785ae9397f4e1a781322ffdab44206166 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_error_if import ErrorIf
9from generator.tosa_error_if import TosaErrorIfArgGen
James Ward8b390432022-08-12 20:48:56 +010010from generator.tosa_utils import get_accum_dtype_from_tgTypes
11from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010012from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from serializer.tosa_serializer import DTypeNames
14from tosa.DType import DType
15from tosa.Op import Op
16from tosa.ResizeMode import ResizeMode
17
18# DTypeNames, DType, Op and ResizeMode are convenience variables to the
19# flatc-generated types that should be enums, but aren't
20
21
class TosaQuantGen:
    """Random generators for quantization zero points.

    An operator picks one of these with the 'qgen' key in its definition.
    Each ``qg*`` generator returns a two-element list of zero points.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point for ``dtype``.

        INT8 / UINT8: use ``testGen.args.zeropoint`` clamped into the type's
        range when it is set, otherwise draw with ``testGen.randInt`` over
        the type's range. Any other dtype gets 0, except when ``error_name``
        is one of the *ZeroPointNotZero ERROR_IF cases, where a random value
        is drawn and forced to be non-zero.
        """
        if dtype in (DType.INT8, DType.UINT8):
            lowest, highest = (-128, 127) if dtype == DType.INT8 else (0, 255)
            requested = testGen.args.zeropoint
            if requested is None:
                return testGen.randInt(lowest, highest + 1)
            return min(highest, max(lowest, requested))

        zero_point_errors = (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        )
        if error_name not in zero_point_errors:
            return 0

        # ERROR_IF test: the zero point must be non-zero to trigger the error
        zero_point = testGen.randInt(-128, 128)
        return zero_point if zero_point != 0 else 1

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a single-input operator.

        Only the zero point selected by ``error_name`` (InputZeroPointNotZero
        or OutputZeroPointNotZero) is generated as an ERROR_IF value.
        """
        input_err = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        output_err = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, input_err),
            TosaQuantGen.getZeroPoint(testGen, dtype, output_err),
        ]

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution-style operator.

        ``dtype_or_dtypeList`` is either a list of [input, weights,
        accumulator] dtypes or a single dtype shared by all three.
        """
        dtypes = (
            dtype_or_dtypeList
            if isinstance(dtype_or_dtypeList, list)
            else [dtype_or_dtypeList] * 3
        )
        input_err = (
            error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        )
        weight_err = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(testGen, dtypes[0], input_err),
            TosaQuantGen.getZeroPoint(testGen, dtypes[1], weight_err),
        ]

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL.

        For InputZeroPointNotZero both zero points are ERROR_IF values.
        """
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        # Evaluated in order, so the RNG is consumed a-then-b
        return [TosaQuantGen.getZeroPoint(testGen, dtype, err) for _ in range(2)]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32. ``scale32`` selects
        a 31-bit multiplier (True) or a 15-bit multiplier (False). The shift
        is nudged into [2, 62], as enforced by the final assertions.
        """
        scale_bits = 31 if scale32 else 15
        full_scale = 1 << scale_bits

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * full_scale)
        assert multiplier <= full_scale

        # Rounding can push the scaled mantissa up to exactly full_scale;
        # renormalise by halving and bumping the exponent
        if multiplier == full_scale:
            multiplier //= 2
            exponent += 1

        shift = scale_bits - exponent

        # Move the shift into the allowed range, compensating in the multiplier
        if shift == 0:
            multiplier, shift = multiplier // 4, 2
        elif shift == 1:
            multiplier, shift = multiplier // 2, 2
        elif shift == 63:
            multiplier, shift = multiplier * 2, 62

        assert multiplier <= full_scale
        assert 2 <= shift <= 62
        return multiplier, shift
154
155
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one copy of a random ``rank`` shape per operand.

        For ErrorIf.RankMismatch the shape is regenerated with a different
        rank after the first operand, so later operands disagree in rank
        with the first.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical rank-4 (NHWC) shapes for every operand.

        The batch dimension is limited to [1, args.max_batch_size] when that
        option is set.
        """
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        # (but not for MaxDimExceeded, which needs the large dimensions)
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for a rank-3 scatter.

        ``input_shape`` shares dims 0 and 2 with ``values_in_shape`` and
        uses a random dim 1 (W).
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return per-operand shapes where one operand is set up to broadcast.

        The chosen operand (never the first one for ERROR_IF tests) gets one
        random dimension set to 1. For DimensionMismatch that dimension is
        grown by one instead. For RankMismatch its rank is changed instead.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank (rank + 1) or slice off
                        # the last two dims, leaving rank - 1
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes (NHWC, OHWI, [O]) for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes (NDHWC, ODHWI, [O]) for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes (NHWC, OHWI, [O]) for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm, filter, bias] shapes (NHWC, HWCM, [C * M]) for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random depth multiplier M, but don't let it get too big
        # because the output depth is M * C. The modulus is clamped to at
        # least 1 so a tensor_shape_range upper bound below 4 does not cause
        # a modulo-by-zero.
        max_filter_m = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeShape(1)[0] % max_filter_m) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter, bias] shapes for FULLY_CONNECTED.

        The filter is [OC, input_shape[1]] and the bias is [OC], with OC
        drawn from args.tensor_shape_range.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape, b_shape] for MATMUL.

        ``b_shape`` is [a[0], a[2], b_oc].
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return shapes for CONCAT: the operator's operands plus extra tensors.

        For ConcatInputRankMismatch every shape after the first has its rank
        changed, either by dropping its first dim or by appending a dim.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create a random number of extra tensors to concat, in addition to
        # the operator's own pl + const operands
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the concat shape along ``axis`` so several const inputs stay small.

        Returns ``shapeList`` unchanged for the axis/rank ERROR_IF cases. For
        ConcatInputDimMismatch a non-concat dimension is enlarged on the
        later shapes.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        remaining_length = shape[axis]
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            # On the final iteration also append the remainder shape
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
542 return new_shapeList
543
544
545class TosaTensorValuesGen:
546 """Tensor Value generators create the random data for each test."""
547
548 def __init__(self):
549 pass
550
551 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000552 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100553 pCount, cCount = op["operands"]
554
555 tens = []
556 tens.extend(
557 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
558 )
559 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
560
561 return tens
562
563 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000564 def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100565 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100566 pCount, cCount = op["operands"]
567 assert (
568 pCount == 1 and cCount == 0
569 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100570 # Must create tensors with values within accumulator (int32) negatable
571 # range
572 max_val = (1 << 31) - 1
573 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100574 arr = np.int32(
575 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
576 )
577 placeholders = []
578 placeholders.append(
579 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
580 )
581 return placeholders
582 else:
583 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000584 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100585 )
586
587 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000588 def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100589 if dtypeList[0] == DType.INT32 and error_name is None:
590 # Make sure the operation does not cause value saturation - where
591 # the number wraps due to limited number of bits to store the answer
592 pCount, cCount = op["operands"]
593 assert (
594 pCount == 2 and cCount == 0
595 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
596 placeholders = []
597 add = op["op"] == Op.ADD
598 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
599 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
600 if add:
601 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
602 else:
603 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
604
605 # Work out the saturation limits
606 max_i32 = (1 << 31) - 1
607 min_i32 = -(1 << 31)
608 max_arr = np.full(shapeList[1], max_i32)
609 min_arr = np.full(shapeList[1], min_i32)
610
611 # Find how much values exceed the maximum/minimums
612 sat_max_arr = np.maximum(res_arr - max_arr, 0)
613 sat_min_arr = np.minimum(res_arr - min_arr, 0)
614
615 if not add:
616 # Swap saturation values and negate values as we need to perform opposite operations
617 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
618
619 # Create new array of unsaturated values by clipping values as needed
620 b_unsat_arr = b_arr
621 if (sat_max_arr != 0).any():
622 # Clip values that cause saturation
623 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
624 # Reduce axes in unsaturated tensor to match original tensor
625 for axis, dim in enumerate(b_arr.shape):
626 if dim != b_unsat_arr.shape[axis]:
627 assert (
628 dim == 1
629 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
630 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
631
632 if (sat_min_arr != 0).any():
633 # Clip values that cause saturation
634 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
635 # Reduce axes in unsaturated tensor to match original tensor
636 for axis, dim in enumerate(b_arr.shape):
637 if dim != b_unsat_arr.shape[axis]:
638 assert (
639 dim == 1
640 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
641 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
642
643 placeholders.append(
644 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
645 )
646 placeholders.append(
647 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
648 )
649
650 return placeholders
651 else:
652 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000653 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100654 )
655
656 @staticmethod
657 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000658 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100659 ):
660 if dtypeList[0] in (
661 DType.INT32,
662 DType.INT16,
663 DType.INT8,
664 ):
665 # Limit input tensors with cond_if_binary or while_loop to stop
666 # saturation of add/sub ops with int32 and keep all logical shift
667 # values between 0 to 31 for int16 or int8
668 pCount, cCount = op["operands"]
669 pRemain = pCount
670 placeholders = []
671 for idx, shape in enumerate(shapeList[:]):
672 if dtypeList[0] == DType.INT32:
673 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
674 else:
675 arr = np.int32(
676 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
677 )
678 if pRemain > 0:
679 placeholders.append(
680 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
681 )
682 pRemain -= 1
683 else:
684 placeholders.append(
685 testGen.ser.addConst(shape, dtypeList[idx], arr)
686 )
687
688 return placeholders
689 else:
690 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000691 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100692 )
693
694 @staticmethod
695 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000696 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100697 ):
698 pCount, cCount = op["operands"]
699 # Force value of operand[1] to be within [0, num_bits]
700 assert (
701 pCount == 2 and cCount == 0
702 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
703
704 placeholders = []
705 for idx, shape in enumerate(shapeList[:]):
706 if idx == 1:
707 if dtypeList[idx] == DType.INT8:
708 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
709 elif dtypeList[idx] == DType.INT16:
710 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
711 elif dtypeList[idx] == DType.INT32:
712 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
713 elif error_name == ErrorIf.WrongInputType:
714 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
715 else:
716 raise Exception("OpArithmeticRightShift: invalid input dtype")
717 else:
718 arr = testGen.getRandTensor(shape, dtypeList[idx])
719 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
720
721 return placeholders
722
723 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000724 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100725 # Set datatype of condition tensor to boolean
726 dtypeList[0] = DType.BOOL
727
728 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000729 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100730 )
731
732 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000733 def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100734 if error_name is None:
735 pCount, cCount = op["operands"]
736 assert (
737 pCount == 2 and cCount == 0
738 ), "Op.INTDIV must have 2 placeholders, 0 consts"
739
740 placeholders = []
741
742 # Two invalid cases for Op.INTDIV:
743 # 1. divisor == 0
744 # 2. dividend == -(1<<31) and divisor == -1
745 while True:
746 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
747 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
748
749 if (divisor_arr == 0).any():
750 continue
751
752 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
753 continue
754
755 break
756
757 placeholders.append(
758 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
759 )
760 placeholders.append(
761 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
762 )
763
764 return placeholders
765 else:
766 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000767 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100768 )
769
770 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000771 def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100772 if error_name is None:
773 pCount, cCount = op["operands"]
774 assert (
775 pCount == 2 and cCount == 0
776 ), "Op.MUL must have 2 placeholders, 0 consts"
777
778 tens = []
James Ward8b390432022-08-12 20:48:56 +0100779 if dtypeList[0] in (DType.FP16, DType.FLOAT):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100780 tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
781 else:
782 placeholders = []
783
784 # Make sure multiply result in int32 range
785 shift = testArgs[0]
786 if dtypeList[0] == DType.INT8:
787 num_bits = 8
788 elif dtypeList[0] == DType.INT16:
789 num_bits = 16
790 elif dtypeList[0] == DType.INT32:
791 num_bits = 32
792 elif error_name == ErrorIf.WrongInputType:
793 num_bits = 8
794 else:
795 raise Exception("OpMul: invalid input dtype")
796
797 for idx, shape in enumerate(shapeList[:]):
798 low = -(2 ** (num_bits - 1))
799 high = (2 ** (num_bits - 1)) - 1
800
801 a_arr = np.int32(
802 testGen.rng.integers(low=low, high=high, size=shapeList[0])
803 )
804 b_arr = np.int32(
805 testGen.rng.integers(low=low, high=high, size=shapeList[1])
806 )
807
808 i = 0
809 while True:
810
811 a_arr_64 = a_arr.astype(np.int64)
812 b_arr_64 = b_arr.astype(np.int64)
813
814 if shift > 0:
815 rounding = 1 << (shift - 1)
816 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
817 else:
818 result_arr = a_arr_64 * b_arr_64
819
820 if (result_arr > -(2**31)).all() and (
821 result_arr <= ((2**31) - 1)
822 ).all():
823 break
824
825 i = i + 1
826 a_arr = a_arr // 2
827 b_arr = b_arr // 2
828
829 placeholders.append(
830 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
831 )
832 placeholders.append(
833 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
834 )
835
836 tens.extend(placeholders)
837
838 return tens
839 else:
840 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000841 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100842 )
843
844 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000845 def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100846 count = len(shapeList) - testGen.args.num_const_inputs_concat
847 if count < 1:
848 count = 1
849 if testGen.args.num_const_inputs_concat == 0:
850 count = len(shapeList)
851
852 # Ensure axis is an int
853 testArgs[0] = int(testArgs[0])
854
855 shapeList = TosaTensorGen.tgConcatConstInput(
856 testGen, shapeList, testArgs[0], error_name
857 )
858
859 tens = []
860 tens.extend(
861 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
862 )
863 tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))
864
865 return tens
866
867 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000868 def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100869 pCount, cCount = op["operands"]
870 assert (
871 pCount == 2 and cCount == 0
872 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
873 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
874 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
875 placeholders = []
876 placeholders.append(
877 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
878 )
879 placeholders.append(
880 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
881 )
882
883 return placeholders
884
885 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000886 def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100887 if error_name is None:
888 pCount, cCount = op["operands"]
889 assert (
890 pCount == 2 and cCount == 0
891 ), "Op.EQUAL must have 2 placeholders, 0 consts"
892 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
893 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
894 # Using random numbers means that it will be very unlikely that
895 # there are any matching (equal) values, therefore force that
896 # there are twice the number of matching values as the tensor rank
897 for num in range(0, len(shapeList[0]) * 2):
898 a_index = []
899 b_index = []
900 # Choose an index in each axis for the whole shape
901 for axis in range(0, len(shapeList[0])):
902 # Index can be up to the largest dimension in both shapes
903 index = np.int32(
904 testGen.rng.integers(
905 0, max(shapeList[0][axis], shapeList[1][axis])
906 )
907 )
908 # Reduce the index down to a shape's dim for broadcasting
909 a_index.append(min(shapeList[0][axis] - 1, index))
910 b_index.append(min(shapeList[1][axis] - 1, index))
911
912 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
913
914 placeholders = []
915 placeholders.append(
916 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
917 )
918 placeholders.append(
919 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
920 )
921 return placeholders
922 else:
923 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000924 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100925 )
926
@staticmethod
def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
    """Create the placeholder input for Op.REDUCE_SUM.

    INT32 inputs get values bounded so that summing along any single axis
    cannot leave the int32 range; every other dtype is delegated to
    ``TosaTensorValuesGen.tvgDefault``.

    Returns:
        A list holding the placeholder created through ``testGen.ser``.
    """
    if dtypeList[0] != DType.INT32:
        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    pCount, cCount = op["operands"]
    assert (
        pCount == 1 and cCount == 0
    ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"

    shape = shapeList[0]
    # A sum along any axis has at most max(shape) terms, so each value is
    # limited to +/- 2**31 / max(shape).  The divisor is clamped to >= 1 so a
    # rank-0 shape (max of an empty sequence) or a zero-sized dimension
    # (division by zero) no longer raises.  Integer division yields the same
    # bound as float division followed by int() for positive sizes.
    longest_axis = max(1, max(shape, default=1))
    range_val = (1 << 31) // longest_axis
    values_arr = np.int32(
        testGen.rng.integers(low=-range_val, high=range_val, size=shape)
    )
    return [testGen.ser.addPlaceholder(shape, dtypeList[0], values_arr)]
949
950
951class TosaArgGen:
952 """Argument generators create exhaustive or random lists of attributes for
953 operators that take attributes or other parameters.
954
955 The return value is a list of (descriptive_name, [arglist]) tuples where
956 the descriptive_name is appended to the test name and the arglist is expanded
957 as arguments to the operator build function.
958 """
959
def __init__(self):
    """Intentionally empty: instances keep no state of their own."""
962
963 @staticmethod
964 def agNone(testGen, opName, shapeList, dtype, error_name=None):
965 """A trivial argument generator for operators that don't take any
966 non-tensor arguments"""
967 return [("", [])]
968
@staticmethod
def agAxis(testGen, opName, shapeList, dtype, error_name=None):
    """Build the axis argument for operators that take a single axis.

    Positive tests get one ("axis<a>", [a]) entry per dimension of the first
    input.  The AxisSmallerZero and AxisLargerRank ERROR_IF tests instead get
    one randomly chosen negative or beyond-rank axis.
    """
    rank = len(shapeList[0])

    if error_name == ErrorIf.AxisSmallerZero:
        chosen_axes = [testGen.rng.integers(-5, 0)]
    elif error_name == ErrorIf.AxisLargerRank:
        chosen_axes = [testGen.rng.integers(rank + 1, rank + 10)]
    else:
        chosen_axes = range(rank)

    return [("axis{}".format(axis), [axis]) for axis in chosen_axes]
986
@staticmethod
def agConv(testGen, opName, shapeList, dtypes, error_name=None):
    """Build a sparse selection of convolution argument sets.

    The kernel size comes from the last "_"-separated part of ``opName``
    (e.g. "conv2d_3x3" -> [3, 3]); conv3d operators use rank-5 tensors, the
    others rank 4.  Every stride/padding/dilation combination is enumerated,
    and every ``sparsity``-th one is kept when the padded input exceeds the
    dilated kernel extent.  Positive tests also require an exact (integer)
    output size; the ConvOutputShapeNonInteger ERROR_IF test keeps only the
    inexact ones.

    Returns:
        A list of (name, [accum_dtype, stride, pad, dilation]) tuples.
    """
    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]
    kernel = [int(size) for size in opName.split("_")[-1].split("x")]

    accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

    rank = 5 if opName.startswith("conv3d") else 4
    if error_name != ErrorIf.WrongRank:
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

    # Spatial rank: drop the batch and channel dimensions
    k_rank = rank - 2
    assert len(kernel) == k_rank

    # Candidate values per parameter; named ERROR_IF tests get invalid ones
    if error_name == ErrorIf.PadSmallerZero:
        pad_choices = [testGen.rng.choice(range(-5, 0))]
    else:
        pad_choices = list(range(testGen.args.max_conv_padding + 1))
    paddings = set(itertools.product(pad_choices, repeat=k_rank * 2))

    if error_name == ErrorIf.StrideSmallerOne:
        # stride=0 is avoided: the stride divides when deriving the output shape
        stride_choices = [testGen.rng.choice(range(-5, 0))]
    else:
        # A non-integer output size needs a stride of at least 2
        first_stride = 2 if error_name == ErrorIf.ConvOutputShapeNonInteger else 1
        stride_choices = list(range(first_stride, testGen.args.max_conv_stride + 1))
    strides = set(itertools.product(stride_choices, repeat=k_rank))

    if error_name == ErrorIf.DilationSmallerOne:
        dilation_choices = [testGen.rng.choice(range(-5, 1))]
    else:
        dilation_choices = list(range(1, testGen.args.max_conv_dilation + 1))
    dilations = set(itertools.product(dilation_choices, repeat=k_rank))

    if not error_name and testGen.args.oversize:
        # Add some oversize argument values
        if max(ifm_shape) < 64:
            paddings.update(itertools.product([0, 9], repeat=k_rank * 2))
        strides.update(itertools.product([1, 8], repeat=k_rank))
        dilations.update(itertools.product([1, 7], repeat=k_rank))

    # Too many combinations to test them all, so keep every `sparsity`-th
    # one (sampling far more densely for negative tests)
    sparsity_factor = 2 if error_name else 120
    sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
    if sparsity < 13:
        # Few enough combinations: use them all
        sparsity = 1
    # A step that is not a multiple of 2, 3 or 5 mixes the parameters better
    while any(sparsity % prime == 0 for prime in (2, 3, 5)):
        sparsity += 1

    def digits(values):
        return "".join(str(v) for v in values)

    def padded_extent(pad, axis):
        return ifm_shape[axis + 1] - 1 + pad[2 * axis] + pad[2 * axis + 1]

    def positive_output(pad, dilation):
        # The padded input must exceed dilation * (kernel - 1) on every axis
        return all(
            padded_extent(pad, axis) > dilation[axis] * (kernel[axis] - 1)
            for axis in range(k_rank)
        )

    def largest_remainder(stride, pad, dilation):
        # Zero means every axis divides exactly (see max() on negative strides)
        return max(
            (padded_extent(pad, axis) - (kernel[axis] - 1) * dilation[axis])
            % stride[axis]
            for axis in range(k_rank)
        )

    want_non_integer = error_name == ErrorIf.ConvOutputShapeNonInteger
    arg_list = []
    combos = itertools.product(sorted(strides), sorted(paddings), sorted(dilations))
    for n, (s, p, d) in enumerate(combos):
        if n % sparsity != 0 or not positive_output(p, d):
            continue
        remainder = largest_remainder(s, p, d)
        if want_non_integer:
            keep = remainder > 0
        else:
            keep = remainder == 0
        if not keep:
            continue
        arg_list.append(
            (
                "acc{}_st{}_pad{}_dilat{}".format(
                    testGen.typeStr(accum_dtype), digits(s), digits(p), digits(d)
                ),
                [accum_dtype, s, p, d],
            )
        )

    return arg_list
1105
@staticmethod
def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
    """Choose the accumulator dtype for FULLY_CONNECTED.

    ``dtypes`` is either one dtype or a list/tuple whose first element is the
    input dtype.  The WrongOutputType ERROR_IF test gets an invalid
    accumulator, WrongInputType gets INT32, and all other cases use
    ``get_accum_dtype_from_tgTypes``.

    Returns:
        A single-entry list: ("acc<type>", [accum_dtype]).
    """
    input_dtype = dtypes[0] if isinstance(dtypes, (list, tuple)) else dtypes

    if error_name == ErrorIf.WrongOutputType:
        accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
    elif error_name == ErrorIf.WrongInputType:
        # Input type is already wrong; any plausible accumulator will do
        accum_dtype = DType.INT32
    else:
        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

    entry = (f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])
    return [entry]
1123
@staticmethod
def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
    """Build the accumulator-dtype argument list for MATMUL.

    Valid accumulators per input dtype: INT8 -> INT32, INT16 -> INT48,
    FP16 -> FP16 or FLOAT, FLOAT -> FLOAT.  The WrongOutputType and
    WrongInputType ERROR_IF tests override this choice.

    Returns:
        One ("acc<type>", [accum_dtype]) entry per accumulator dtype.

    Raises:
        AssertionError: unsupported input dtype in a positive test.
    """
    if dtype == DType.INT8:
        accum_dtypes = [DType.INT32]
    elif dtype == DType.INT16:
        accum_dtypes = [DType.INT48]
    elif dtype == DType.FP16:
        accum_dtypes = [DType.FP16, DType.FLOAT]
    elif dtype == DType.FLOAT:
        accum_dtypes = [DType.FLOAT]
    elif error_name is None:
        # Explicit raise: a bare `assert False` is stripped under `python -O`,
        # which would leave accum_dtypes unbound below.
        raise AssertionError(f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}")
    else:
        # ERROR_IF test with an unsupported input dtype: use a placeholder
        # accumulator (as agPooling does) so accum_dtypes is always bound,
        # even for error cases not overridden below.
        accum_dtypes = [DType.INT32]

    if error_name == ErrorIf.WrongOutputType:
        # Deliberately invalid output dtype for the ERROR_IF case
        accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
    elif error_name == ErrorIf.WrongInputType:
        # Pick a potentially correct output dtype when the input is wrong
        accum_dtypes = [DType.INT32]

    return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
1146
@staticmethod
def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
    """Build a sparse selection of TRANSPOSE_CONV2D argument sets.

    Stride and output-padding combinations are enumerated and every
    ``sparsity``-th one is kept.  The output shape for each is
    [N, (IH - 1) * stride_y + pad_top + pad_bottom + KH,
        (IW - 1) * stride_x + pad_left + pad_right + KW, OC].

    Returns:
        A list of (name, [accum_dtype, stride, pad, out_shape]) tuples.
    """
    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]
    accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

    if error_name != ErrorIf.WrongRank:
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

    kernel_h, kernel_w = filter_shape[1], filter_shape[2]
    # Lowest padding tried: one above minus the smaller kernel dimension
    min_pad = 1 - min(kernel_h, kernel_w)

    if error_name == ErrorIf.PadLargerEqualKernel:
        neg_largest_kernel = -max(kernel_h, kernel_w)
        pad_choices = [
            testGen.rng.choice(range(neg_largest_kernel - 10, neg_largest_kernel))
        ]
    else:
        pad_choices = list(range(min_pad, testGen.args.max_conv_padding + 1))
    paddings = set(itertools.product(pad_choices, repeat=4))

    if error_name == ErrorIf.StrideSmallerOne:
        # stride=0 is avoided: the stride is used to derive the output shape
        stride_choices = [testGen.rng.choice(range(-5, 0))]
    else:
        stride_choices = list(range(1, testGen.args.max_conv_stride + 1))
    strides = set(itertools.product(stride_choices, repeat=2))

    if not error_name and testGen.args.oversize:
        # Add some oversize argument values
        if max(ifm_shape) < 64:
            paddings.update(itertools.product([min_pad, 9], repeat=4))
        strides.update(itertools.product([1, 8], repeat=2))

    # Too many combinations, so sample them sparsely (denser for negative tests)
    sparsity_factor = 2 if error_name else 10
    sparsity = len(paddings) * len(strides) // sparsity_factor + 1
    if sparsity < 13:
        # Few enough combinations: use them all
        sparsity = 1
    # A step that is not a multiple of 2, 3 or 5 mixes the parameters better
    while any(sparsity % prime == 0 for prime in (2, 3, 5)):
        sparsity += 1

    def joined(values, sep=""):
        return sep.join(str(v) for v in values)

    arg_list = []
    for n, (s, p) in enumerate(itertools.product(sorted(strides), sorted(paddings))):
        if n % sparsity != 0:
            continue
        # NOTE(review): with the lowest paddings and a small input the size
        # can be <= 0 (e.g. IH=2, stride 1, 3x3 kernel, pads -2/-2 gives 0);
        # nothing here filters that out - confirm downstream handling.
        out_h = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + kernel_h
        out_w = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + kernel_w
        out_shape = [ifm_shape[0], out_h, out_w, filter_shape[0]]
        arg_list.append(
            (
                "acc{}_st{}_pad{}_os{}".format(
                    testGen.typeStr(accum_dtype),
                    joined(s),
                    joined(p),
                    joined(out_shape, "x"),
                ),
                [accum_dtype, s, p, out_shape],
            )
        )

    return arg_list
1229
@staticmethod
def agPad(testGen, opName, shapeList, dtype, error_name=None):
    """Enumerate PAD argument sets.

    Every (before, after) padding combination over all dimensions is tried.
    Positive tests use the values 0..1 per side.  The PadSmallerZero
    ERROR_IF test uses -2..-1 and keeps only combinations where each
    dimension stays >= 1 and has a negative pad.  One random pad constant is
    chosen per call (integer slot for BOOL/INT types, float slot for
    FP16/FLOAT); other dtypes get no tests.

    Returns:
        A list of (name, [np.array(paddings), pad_const_int, pad_const_fp]).
    """
    in_shape = shapeList[0]
    rank = len(in_shape)
    negative_test = error_name == ErrorIf.PadSmallerZero

    # Names concatenate the values, so padding > 9 would need a more
    # distinctive name format.
    pad_min, pad_max = (-2, -1) if negative_test else (0, 1)
    side_values = list(range(pad_min, pad_max + 1))
    per_axis = list(itertools.product(side_values, repeat=2))

    if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
        pad_const_int, pad_const_fp = testGen.getRandNumberDType(dtype), 0
    elif dtype in (DType.FP16, DType.FLOAT):
        pad_const_int, pad_const_fp = 0, testGen.getRandNumberDType(dtype)
    else:
        return []

    def usable_negative(axis_pads):
        # Avoid non-positive output dims while still padding negatively
        return all(
            before + after + dim >= 1 and not all(v > -1 for v in (before, after))
            for (before, after), dim in zip(axis_pads, in_shape)
        )

    arg_list = []
    for axis_pads in itertools.product(per_axis, repeat=rank):
        if negative_test and not usable_negative(axis_pads):
            continue
        name = "pad" + "".join(f"{before}{after}" for before, after in axis_pads)
        arg_list.append(
            (name, [np.array(list(axis_pads)), pad_const_int, pad_const_fp])
        )

    if negative_test and not arg_list:
        warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

    return arg_list
1282
@staticmethod
def agPooling(testGen, opName, shapeList, dtype, error_name=None):
    """Build a sparse selection of pooling argument sets.

    Accumulator dtypes: none for "max_pool2d"; otherwise INT8/INT16 -> INT32,
    FP16 -> FP16 or FLOAT, FLOAT -> FLOAT.  Stride/padding/kernel
    combinations are enumerated and sampled every ``sparsity`` steps.  The
    stride, kernel and padding ERROR_IF tests obtain their invalid values
    from ``TosaErrorIfArgGen.eiPoolingErrorIf``.  Positive tests need padding
    below the kernel size, a padded input larger than the kernel and an
    exact output size; PoolingOutputShapeNonInteger keeps only inexact ones.

    Returns:
        A list of (name, [accum?, stride, pad, kernel]) tuples; the accum
        entry is only present when an accumulator dtype applies.
    """
    shape = shapeList[0]
    if error_name != ErrorIf.WrongRank:
        assert len(shape) == 4

    paddings = set(
        itertools.product(range(testGen.args.max_pooling_padding + 1), repeat=4)
    )
    # A non-integer output size needs a stride of at least 2
    first_stride = 2 if error_name == ErrorIf.PoolingOutputShapeNonInteger else 1
    strides = set(
        itertools.product(
            range(first_stride, testGen.args.max_pooling_stride + 1), repeat=2
        )
    )
    kernels = set(
        itertools.product(range(2, testGen.args.max_pooling_kernel + 1), repeat=2)
    )

    if opName == "max_pool2d":
        accum_dtypes = [None]  # max_pool has no accumulate dtype
    elif dtype == DType.INT8 or dtype == DType.INT16:
        accum_dtypes = [DType.INT32]
    elif dtype == DType.FP16:
        accum_dtypes = [DType.FP16, DType.FLOAT]
    elif dtype == DType.FLOAT:
        accum_dtypes = [DType.FLOAT]
    elif error_name is None:
        assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
    else:
        # ERROR_IF test with an incorrect input dtype still needs a value
        accum_dtypes = [DType.INT32]

    if testGen.args.oversize:
        # Add some oversize argument values
        big_kernel = 9
        strides.update(itertools.product([first_stride, 7], repeat=2))
        kernels.update(itertools.product([2, big_kernel], repeat=2))
        if max(shape) < 64:
            # padding must stay below the kernel size
            paddings.update(itertools.product([0, big_kernel - 1], repeat=4))

    # Too many combinations, so sample them sparsely (denser for negative tests)
    sparsity_factor = 2 if error_name else 500
    sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

    name_template = (
        "acc{}_st{}_kern{}_pad{}"
        if accum_dtypes[0] is not None
        else "st{}_kern{}_pad{}"
    )

    def digits(values):
        return "".join(str(v) for v in values)

    def make_entry(accum, stride, pad, kern):
        # The name lists stride, kernel, pad; the values are stride, pad, kernel
        name_parts = [digits(stride), digits(kern), digits(pad)]
        values = [stride, pad, kern]
        if accum is not None:
            name_parts.insert(0, testGen.typeStr(accum))
            values.insert(0, accum)
        return (name_template.format(*name_parts), values)

    def fits(pad, kern):
        # Padding below the kernel size, padded input larger than the kernel
        return (
            pad[0] < kern[0]
            and pad[1] < kern[0]
            and pad[2] < kern[1]
            and pad[3] < kern[1]
            and (shape[1] + pad[0] + pad[1]) > kern[0]
            and (shape[2] + pad[2] + pad[3]) > kern[1]
        )

    error_if_params = (
        ErrorIf.StrideSmallerOne,
        ErrorIf.KernelSmallerOne,
        ErrorIf.PadSmallerZero,
        ErrorIf.PadLargerEqualKernel,
    )
    want_non_integer = error_name == ErrorIf.PoolingOutputShapeNonInteger

    arg_list = []
    combos = itertools.product(
        accum_dtypes, sorted(strides), sorted(paddings), sorted(kernels)
    )
    for n, (a, s, p, k) in enumerate(combos):
        if error_name in error_if_params:
            # The ERROR_IF helper is consulted for every combination; a None
            # in its result marks the combination as unusable
            s_new, p_new, k_new = TosaErrorIfArgGen.eiPoolingErrorIf(
                testGen, error_name, s, p, k
            )
            if None not in [s_new, p_new, k_new] and n % sparsity == 0:
                arg_list.append(make_entry(a, s_new, p_new, k_new))
            continue

        if n % sparsity != 0 or not fits(p, k):
            continue

        remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
        remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
        exact = remainder_h == 0 and remainder_w == 0
        if want_non_integer:
            keep = not exact
        else:
            keep = exact
        if keep:
            arg_list.append(make_entry(a, s, p, k))

    return arg_list
1402
@staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate the CAST output dtypes for ``inDtype``.

    The WrongOutputType ERROR_IF test takes its output types from
    ``TosaErrorIfArgGen.eiCastErrorIf``.  An unlisted input dtype is only
    accepted for the WrongInputType ERROR_IF test.

    Returns:
        One ("out<TYPE>", [dtype]) entry per target dtype.

    Raises:
        Exception: unexpected input dtype outside the WrongInputType test.
    """
    if error_name == ErrorIf.WrongOutputType:
        out_dtypes = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
    else:
        to_int = [DType.INT8, DType.INT16, DType.INT32]
        cast_table = (
            (DType.INT8, [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]),
            (DType.INT16, [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]),
            (DType.INT32, [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]),
            (DType.BOOL, to_int),
            (DType.FP16, to_int),
            (DType.FLOAT, to_int),
        )
        for source, targets in cast_table:
            if inDtype == source:
                out_dtypes = targets
                break
        else:
            if error_name == ErrorIf.WrongInputType:
                # Some potentially correct outputs for an incorrect input type
                out_dtypes = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
            else:
                raise Exception("Unexpected input dtype: {}".format(inDtype))

    return [("out{}".format(DTypeNames[out]), [out]) for out in out_dtypes]
1432
@staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate RESCALE argument sets.

    For each allowed output dtype, every (scale32, double_round, per_channel)
    combination is considered.  Combinations illegal for a positive test, or
    irrelevant to the requested ERROR_IF test, are dropped.

    Returns:
        A list of (name, [out_dtype, scale32, double_round, per_channel]).
    """
    wrong_output_test = error_name == ErrorIf.WrongOutputType

    def skip_output(outDtype):
        # True when no test should be generated for this output dtype
        if (
            outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
            and error_name == ErrorIf.OutputZeroPointNotZero
        ):
            return True
        # The UINT16 zero-point ERROR_IFs only apply to UINT16 tensors
        if outDtype != DType.UINT16 and error_name == ErrorIf.U16OutputZeroPointNotValid:
            return True
        if inDtype != DType.UINT16 and error_name == ErrorIf.U16InputZeroPointNotValid:
            return True
        if wrong_output_test:
            # Only keep genuinely invalid input/output pairings
            return not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
        # UINT8 only converts to/from INT8 or INT16
        if inDtype == DType.UINT8 and outDtype not in [DType.INT8, DType.INT16]:
            return True
        if inDtype not in [DType.INT8, DType.INT16] and outDtype == DType.UINT8:
            return True
        # UINT16 only converts to/from INT16
        if inDtype == DType.UINT16 and outDtype != DType.INT16:
            return True
        if inDtype != DType.INT16 and outDtype == DType.UINT16:
            return True
        return False

    arg_list = []
    candidates = [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.UINT16]
    for outDtype in candidates:
        if skip_output(outDtype):
            continue
        flags = itertools.product([False, True], repeat=3)
        for scale32, double_round, per_channel in flags:
            if error_name == ErrorIf.ScaleTrue and not scale32:
                continue
            if error_name == ErrorIf.ScaleNotTrue and (scale32 or not double_round):
                continue
            if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
                # INT48 input requires scale32=False
                continue
            if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
                # ERROR_IF(!scale32 && double_round)
                continue
            arg_list.append(
                (
                    "out{}_sc{}_dr{}_pc{}".format(
                        DTypeNames[outDtype],
                        int(scale32),
                        int(double_round),
                        int(per_channel),
                    ),
                    [outDtype, scale32, double_round, per_channel],
                )
            )

    return arg_list
1531
@staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Build the shift argument for MUL.

    INT32 multiplication gets ``testGen.args.num_rand_permutations`` random
    shifts drawn with ``testGen.randInt(0, 32)``; every other dtype gets a
    single shift of 0.

    Returns:
        A list of ("perm<p>_shift<shift>", [shift]) entries.
    """
    # Compare by value: the previous `is` identity test on integer dtype
    # codes only worked by virtue of CPython's small-int caching and fails
    # for equal values of another int type (e.g. numpy integers).
    if dtype != DType.INT32:
        return [("perm0_shift0", [0])]

    arg_list = []
    for p in range(testGen.args.num_rand_permutations):
        shift = testGen.randInt(0, 32)
        arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
    return arg_list
1546
1547 @staticmethod
1548 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
1549 arg_list = []
1550
1551 arg_list.append(("roundTrue", [True]))
1552 arg_list.append(("roundFalse", [False]))
1553
1554 return arg_list
1555
1556 # Helper function for reshape. Gets some factors of a larger number.
1557 @staticmethod
1558 def getFactors(val, start=1):
1559 factors = []
1560
1561 for i in range(start, int(np.sqrt(val)) + 1):
1562 if (val % i) == 0:
1563 factors.append(i)
1564
1565 return factors
1566
@staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random RESHAPE target shapes with the same element count.

    Up to ``testGen.args.num_rand_permutations`` shapes are produced.  Each
    has a rank from ``testGen.randInt(1, 7)`` and is built from factors of
    the total element count.  A duplicate shape is regenerated, up to 100
    attempts, before that permutation is dropped.

    Returns:
        A list of ("perm<p>_rank<r>", [new_shape]) entries.
    """
    total_elements = 1
    for dim in shapeList[0]:
        total_elements *= dim

    # This is NOT fast; fortunately the element counts are fairly small
    factors = TosaArgGen.getFactors(total_elements)

    def random_shape(rank):
        # Peel off rank-1 random factors; whatever remains is the last dim
        remaining = total_elements
        candidates = testGen.rng.permutation(factors)
        dims = []
        for _ in range(rank - 1):
            dims.append(candidates[0])
            remaining = remaining // candidates[0]
            candidates = testGen.rng.permutation(TosaArgGen.getFactors(remaining))
        dims.append(remaining)
        return dims

    arg_list = []
    for p in range(testGen.args.num_rand_permutations):
        new_rank = testGen.randInt(1, 7)
        if len(factors) < new_rank:
            continue
        # Bounded retries so an exhausted shape space cannot loop forever
        for _attempt in range(100):
            new_shape = random_shape(new_rank)
            if not any(args[0] == new_shape for _, args in arg_list):
                arg_list.append(("perm{}_rank{}".format(p, new_rank), [new_shape]))
                break

    return arg_list
1617
@staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
    """Pick random TRANSPOSE permutations of the input's axes.

    IndexOutsideBounds uses permutations of out-of-range indices (all too
    large, or all negative).  IndexUsedTwice uses permutations of an index
    list containing a duplicate.  At most
    ``testGen.args.num_rand_permutations`` entries are returned.

    Returns:
        A list of ("perm<i>", [permutation_as_list]) entries.
    """
    in_rank = len(shapeList[0])

    if error_name == ErrorIf.IndexOutsideBounds:
        too_large = range(in_rank + 1, 2 * in_rank + 1)
        too_small = range(-in_rank, 0)
        permutations = list(itertools.permutations(too_large))
        permutations += list(itertools.permutations(too_small))
    elif error_name == ErrorIf.IndexUsedTwice:
        # Copy one index over its neighbour to create a duplicate.
        # NOTE(review): for rank 1 the neighbour is the index itself, so no
        # duplicate is produced - confirm rank-1 inputs never reach here.
        axes = list(range(in_rank))
        dup = testGen.rng.choice(range(len(axes)))
        axes[(dup + 1) % len(axes)] = axes[dup]
        permutations = list(itertools.permutations(axes))
    else:
        permutations = list(itertools.permutations(range(in_rank)))

    # Bounded by the available permutations and by the argument setting
    count = min(len(permutations), testGen.args.num_rand_permutations)

    # Shuffle the full set so the selection uses all permutations fairly
    shuffled = testGen.rng.permutation(permutations)

    return [("perm{}".format(i), [shuffled[i].tolist()]) for i in range(count)]
1654
@staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random SLICE start/size pairs.

    For each permutation, dims larger than 1 get a random start and size
    from ``testGen.randInt``; size-1 dims use start 0, size 1.  A draw with
    a zero size is discarded.  Kept pairs go through
    ``TosaErrorIfArgGen.eiSliceErrorIf``, which may replace them for
    ERROR_IF tests.

    Returns:
        A list of ("perm<p>", [start, size]) entries.
    """
    ifm_shape = shapeList[0]
    arg_list = []

    for p in range(testGen.args.num_rand_permutations):
        start = []
        size = []
        valid = True

        for dim in ifm_shape:
            if dim > 1:
                begin = testGen.randInt(0, dim)
                length = testGen.randInt(0, dim - begin)
                start.append(begin)
                size.append(length)
                if length == 0:
                    # Invalid slice size; keep drawing so RNG use is unchanged
                    valid = False
            else:
                start.append(0)
                size.append(1)

        if not valid:
            continue

        # ERROR_IF tests get an incorrect start/size back from the helper
        start, size = TosaErrorIfArgGen.eiSliceErrorIf(
            testGen, error_name, ifm_shape, start, size
        )
        arg_list.append(("perm{}".format(p), [start, size]))

    return arg_list
1687
1688 @staticmethod
1689 def agTile(testGen, opName, shapeList, dtype, error_name=None):
1690 arg_list = []
1691
1692 ifm_shape = shapeList[0]
1693 rank = len(ifm_shape)
1694
1695 for p in range(testGen.args.num_rand_permutations):
1696
1697 # Pick a few random, but small multiple values
1698 # because otherwise this has a tendency to generate
1699 # enormous tensors
1700 multiples = []
1701 for i in range(rank):
1702 if ifm_shape[i] > 1000:
1703 # Multiple of 1 if ifm_shape dimension is large to reduce
1704 # tensor size
1705 multiples.append(1)
1706 elif max(ifm_shape) > 1000:
1707 multiples.append(2)
1708 else:
1709 multiples.append(testGen.randInt(1, 4))
1710 arg_list.append(("perm{}".format(p), [multiples]))
1711
1712 return arg_list
1713
@staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Generate RESIZE argument lists.

    For each resize mode (NEAREST, BILINEAR) and each legal output dtype for
    the input ``dtype``, repeatedly draw (scale, offset, border) parameter
    sets from one of three random generators below. Each draw computes the
    resulting output height/width and is kept only if it passes the bounds
    checks in the main loop. This continues until
    ``testGen.args.num_rand_permutations`` attempts have been counted for
    that mode/output-dtype pair. For ERROR_IF tests (``error_name`` set) the
    kept parameters and output dtype are passed through
    ``TosaErrorIfArgGen.eiResizeErrorIf``, which may replace them.

    NOTE(review): every ``testGen.rng`` / ``testGen.randInt`` call consumes
    random state, so reordering these calls changes which tests are
    generated for a given seed (assuming ``testGen.rng`` is seeded).

    Args:
        testGen: test generator. This method reads ``testGen.rng``,
            ``testGen.randInt``, ``testGen.typeStr`` and ``testGen.args``
            (``num_rand_permutations``, ``max_resize_output_dim``).
        opName: operator name (not used by this method).
        shapeList: list of input shapes. ``shapeList[0]`` is the input
            tensor shape; its dims 1 and 2 are used for the y (OH) and
            x (OW) size computations respectively.
        dtype: input tensor dtype.
        error_name: optional ``ErrorIf`` condition to build a negative test
            for; ``None`` for positive tests.

    Returns:
        A list of ``(arg_str, [mode, scale, offset, border, dtype,
        output_dtype])`` tuples with no duplicates. For positive tests,
        scale is ``[scale_y_n, scale_y_d, scale_x_n, scale_x_d]``, offset
        is ``[offset_y, offset_x]`` and border is ``[border_y, border_x]``.
    """
    arg_list = []
    ifm_shape = shapeList[0]

    def get_aspect_ratio_resize_params():
        """Return (scale, offset, border) for a common-aspect-ratio resize.

        One of 3:2, 16:9 or 4:3 (optionally swapped between y and x) is used
        as the scale numerators, with denominators of 1 and zero offsets.
        One axis gets a border drawn with ``testGen.randInt(0, scale_*_n)``
        (letterbox -> y axis, otherwise pillarbox -> x axis). The other
        axis's border is 0.
        """
        common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
        aspect_ratio = testGen.rng.choice(common_aspect_ratios)
        invert = testGen.rng.choice((False, True))
        letterbox = testGen.rng.choice((False, True))

        scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
        scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
        scale_y_d = scale_x_d = 1
        offset_x = offset_y = 0

        if letterbox:
            max_border = scale_y_n
            # NOTE(review): randInt's upper bound is assumed to be exclusive
            # (like numpy Generator.integers) - confirm against testGen.
            border_y = testGen.randInt(low=0, high=max_border)
            border_x = 0
        else:
            # Pillarboxing
            border_y = 0
            max_border = scale_x_n
            border_x = testGen.randInt(low=0, high=max_border)

        scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        offset = (offset_y, offset_x)
        border = (border_y, border_x)

        return scale, offset, border

    def get_upscale_downscale_params():
        """Return (scale, offset, border) for a power-of-two upscale or a small downscale.

        Upscale: the numerators are ``1 << shift`` (origin sampling) or
        ``2 << shift`` (half-pixel sampling), with ``shift`` drawn by
        ``testGen.randInt(low=1, high=4)``, and the denominators are 1.
        Without origin sampling, border is ``(1 << shift) - 1`` and offset
        is ``-(1 << shift) + 1``; otherwise both are 0.

        Downscale: the numerators are 1 and each denominator is chosen from
        ``get_valid_denom`` for that axis, or 1 if the axis has no
        candidates. Border is 0 and offsets are drawn with
        ``testGen.randInt(0, 16 * scale_*_n)``. If neither axis has a
        candidate denominator, the whole draw is repeated.
        """
        valid_params = False
        while not valid_params:
            upscale = testGen.rng.choice((False, True))

            # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
            origin_sampling = testGen.rng.choice((False, True))

            if upscale:
                shift = testGen.randInt(low=1, high=4)
                scale_x_d = scale_y_d = 1
                scale_x_n = scale_y_n = (
                    1 << shift if origin_sampling else 2 << shift
                )
                border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
            else:
                scale_x_n = 1
                scale_y_n = 1

                # Return list of valid scale_*_d values (max value 4) given input dim shape
                # Note: x == 1 never matches (ifm_dim % 1 is always 0), so
                # only 2..4 can be returned. For those, `ifm_dim % x == 1`
                # is the same as `(ifm_dim - 1) % x == 0`, i.e. x divides
                # (ifm_dim - 1) exactly.
                def get_valid_denom(ifm_dim):
                    return [x for x in range(1, 5) if ifm_dim % x == 1]

                # Generate list of valid downscale values and choose one randomly
                valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                if not valid_scale_y_ds and not valid_scale_x_ds:
                    # Bad parameters, skip
                    # (redraws upscale/origin_sampling on the next pass)
                    continue

                if not valid_scale_y_ds:
                    scale_y_d = 1
                else:
                    scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                if not valid_scale_x_ds:
                    scale_x_d = 1
                else:
                    scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                border_x = border_y = 0
                offset_y = testGen.randInt(0, 16 * scale_y_n)
                offset_x = testGen.randInt(0, 16 * scale_x_n)
            valid_params = True

        scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        offset = (offset_y, offset_x)
        border = (border_y, border_x)
        return scale, offset, border

    def get_rand_params():
        """Return fully random (scale, offset, border).

        Per axis, with n the scale numerator:
        - n is drawn from ``randInt(1, 1 << 11)``;
        - the denominator from ``randInt(1, 16 * n)``;
        - the offset from ``randInt(-n, 16 * n)``;
        - the border from ``randInt(-16 * n, n)``.

        The main loop then adjusts the denominators or discards draws that
        give an invalid output size.
        """
        # Scale
        scale_y_n = testGen.randInt(low=1, high=(1 << 11))
        scale_x_n = testGen.randInt(low=1, high=(1 << 11))

        scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
        scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

        # Offsets and border within the scale
        offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
        offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
        border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
        border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

        scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        offset = (offset_y, offset_x)
        border = (border_y, border_x)
        return scale, offset, border

    for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
        # Exclude illegal {mode, type} configurations. Pick legal output types
        if mode == ResizeMode.NEAREST and dtype == DType.INT8:
            outputDTypeList = [DType.INT8]
        elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
            outputDTypeList = [DType.INT16]
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
            outputDTypeList = [DType.INT32]
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
            outputDTypeList = [DType.INT48]
        elif dtype == DType.FP16:
            outputDTypeList = [DType.FP16]
        elif dtype == DType.FLOAT:
            outputDTypeList = [DType.FLOAT]
        elif error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            continue

        arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

        for outputDType in outputDTypeList:
            # `perm` counts attempts that are accepted or deliberately given
            # up on. Any `continue` below that does not increment it causes
            # a fresh random draw instead.
            perm = 0
            while perm < testGen.args.num_rand_permutations:
                # Random choice of type of params we are testing
                _rnd_param_fn = testGen.rng.choice(
                    (
                        get_rand_params,
                        get_upscale_downscale_params,
                        get_aspect_ratio_resize_params,
                    )
                )
                scale, offset, border = _rnd_param_fn()

                # Expand params for bounds-checking
                (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                (offset_y, offset_x) = offset
                (border_y, border_x) = border

                # Make sure output dimensions OH and OW are integers
                # Per axis the output size is partial_output // scale_d + 1.
                # It is exact only when scale_d divides partial_output.
                partial_output_y = (
                    (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                )
                partial_output_x = (
                    (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                )
                if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                    if (
                        partial_output_y % scale_y_d == 0
                        and partial_output_x % scale_x_d == 0
                    ):
                        # Skip this test as it doesn't produce NonInteger output
                        perm += 1
                        continue
                else:
                    # Shrink each denominator until it divides exactly. This
                    # always terminates because any integer % 1 == 0.
                    while partial_output_y % scale_y_d != 0:
                        scale_y_d -= 1
                    while partial_output_x % scale_x_d != 0:
                        scale_x_d -= 1

                output_y = partial_output_y // scale_y_d + 1
                output_x = partial_output_x // scale_x_d + 1

                if (
                    output_y >= testGen.args.max_resize_output_dim
                    or output_x >= testGen.args.max_resize_output_dim
                ) and error_name is None:
                    # Skip positive test if output dim will be too high
                    # Avoid high test latency and OOM issues
                    perm += 1
                    continue

                if (
                    output_y <= 0
                    or output_y >= MAX_RESIZE_DIMENSION
                    or output_x <= 0
                    or output_x >= MAX_RESIZE_DIMENSION
                ):
                    # Output dimensions out of scope
                    # Positive tests (and the first ERROR_IF attempt) redraw;
                    # later ERROR_IF attempts are counted and dropped.
                    if error_name is not None and perm > 0:
                        # As long as we have one ERROR_IF test, don't worry
                        # about creating all the other permutations
                        perm += 1
                    continue

                if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                    (
                        output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                        and output_y - scale_y_d < 1
                    )
                    or (
                        output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                        and output_x - scale_x_d < 1
                    )
                ):
                    # Can't create a negative test with these params as it
                    # will create invalid output size
                    if perm > 0:
                        perm += 1
                    continue

                # Repack as lists (possibly with adjusted denominators)
                scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                offset = [offset_y, offset_x]
                border = [border_y, border_x]

                # Common for all data types
                if error_name is not None:
                    (
                        scale,
                        offset,
                        border,
                        outputDTypeNew,
                    ) = TosaErrorIfArgGen.eiResizeErrorIf(
                        testGen,
                        error_name,
                        mode,
                        dtype,
                        shapeList,
                        outputDType,
                        scale,
                        offset,
                        border,
                    )
                else:
                    outputDTypeNew = outputDType

                arg_to_append = (
                    arg_str.format(
                        "N" if mode == ResizeMode.NEAREST else "B",
                        testGen.typeStr(outputDTypeNew),
                        scale[0],
                        scale[1],
                        scale[2],
                        scale[3],
                        offset[0],
                        offset[1],
                        border[0],
                        border[1],
                    ),
                    [
                        mode,
                        scale,
                        offset,
                        border,
                        dtype,
                        outputDTypeNew,
                    ],
                )
                if arg_to_append in arg_list:
                    # Skip already generated test params
                    # NOTE(review): duplicates do not advance `perm`, so this
                    # relies on the random generators eventually producing a
                    # new combination.
                    continue

                # Valid permutation
                perm += 1
                arg_list.append(arg_to_append)
    return arg_list
1975
@staticmethod
def agTable(testGen, opName, shapeList, dtype, error_name=None):
    """Build the single argument set for TABLE: a random lookup table.

    INT8 input gets 256 random entries in [-128, 128) drawn from
    ``testGen.rng.integers``. Any other dtype is treated as INT16 and gets
    513 random entries in [-32768, 32768). Those entries are then
    post-processed so that every difference between neighbouring entries
    (the slope) fits in a signed 16-bit integer.

    Returns:
        ``[("", [table])]`` where ``table`` is a list of Python ints.
    """
    if dtype == DType.INT8:
        entries = (
            testGen.rng.integers(low=-128, high=128, size=[256])
            .astype(np.int32)
            .tolist()
        )
    else:  # INT16
        entries = (
            testGen.rng.integers(low=-32768, high=32768, size=[513])
            .astype(np.int32)
            .tolist()
        )
        # Make sure all slopes are within REQUIRE min/max 16-bit int.
        # Walk left to right and clamp each entry into
        # [prev - 32768, prev + 32767]. Moving an entry also changes the
        # slope to its right, which the next iteration then checks.
        slope_min, slope_max = -32768, 32767
        for pos in range(1, len(entries)):
            prev = entries[pos - 1]
            entries[pos] = min(max(entries[pos], prev + slope_min), prev + slope_max)
            step = entries[pos] - prev
            assert slope_min <= step <= slope_max
    return [("", [entries])]
2005
@staticmethod
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Generate COND_IF arguments: one test per boolean condition value.

    Only the condition values are produced here. The build function turns
    them into tensors and constructs the then/else blocks.

    Declared ``@staticmethod`` like the other argument generators, because
    the first parameter is the test generator, not ``self``.

    Returns:
        ``[("cond0", [False]), ("cond1", [True])]``.
    """
    return [("cond{}".format(int(cond)), [cond]) for cond in (False, True)]
2016
@staticmethod
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Generate WHILE_LOOP arguments: iteration counts of 0, 1 and 4.

    Covers a loop body that never runs, runs exactly once, and runs more
    than once. Declared ``@staticmethod`` like the other argument
    generators, because the first parameter is the test generator, not
    ``self``.

    Returns:
        ``[("iter0", [0]), ("iter1", [1]), ("iter4", [4])]``.
    """
    # `count` rather than `iter`, to avoid shadowing the builtin.
    return [("iter{}".format(count), [count]) for count in (0, 1, 4)]