# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_utils import get_accum_dtype_from_tgTypes
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience aliases for the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

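    # Quick reference for getZeroPoint above (informative comment): INT8
    # yields a zero point in [-128, 127] and UINT8 in [0, 255] (both clamped
    # to any --zeropoint override); the listed ErrorIf cases deliberately
    # return a non-zero value; all other dtypes get a zero point of 0.
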
    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #   scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift

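    # Worked example for computeMultiplierAndShift (informative comment):
    # computeMultiplierAndShift(1.0, True) gives math.frexp(1.0) == (0.5, 1),
    # so multiplier = round(0.5 * (1 << 31)) = 1 << 30 and
    # shift = -1 + 31 = 30; the fixed-point rescale
    # (value * (1 << 30)) >> 30 then reproduces a scale of 1.0 exactly.

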
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

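    # Example for tgBroadcastFuzz (informative comment): with base shape
    # [3, 4, 5], two placeholder inputs and no error_name, bcast_idx may
    # select input 1 and fuzz_idx dimension 2, producing shape_list
    # [[3, 4, 5], [3, 4, 1]] so the second operand broadcasts on that axis.
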
    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

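    # Shape example for tgConv2D (informative comment): for "conv2d_3x3"
    # with a random NHWC input of [1, 16, 16, 8] and a random OFM depth of 4,
    # this returns [[1, 16, 16, 8], [4, 3, 3, 8], [4]] - the input, the OHWI
    # filter and the per-output-channel bias.
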
    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

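    # Shape example for tgDepthwiseConv2D (informative comment): for an NHWC
    # input of [1, 8, 8, 3] with a 2x2 kernel and a channel multiplier M of
    # 2, this returns [[1, 8, 8, 3], [2, 2, 3, 2], [6]] - the bias length is
    # C * M = 6.
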
    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

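    # Shape example for tgMatmul (informative comment): an A operand of
    # [2, 5, 7] pairs with a B operand of [2, 7, b_oc], so the batched
    # matmul contracts the shared dimension of 7 and yields [2, 5, b_oc].
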
    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList

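    # Split example for tgConcatConstInput (informative comment): three
    # inputs of shape [2, 8, 3] concatenated on axis 1 become
    # [[2, 8, 3], [2, 4, 3], [2, 4, 3]] - the first input keeps the full
    # shape while the remaining const inputs split the axis-1 length of 8
    # into 4 + 4.

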
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much the values exceed the maximum/minimum
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

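    # Numeric example for tvgAddSub (informative comment): if an element of
    # a is 2_000_000_000 and the matching element of b is 500_000_000, the
    # sum exceeds the int32 maximum of 2_147_483_647 by 352_516_353, so b is
    # clipped down by exactly that amount and the pair adds without
    # saturating.
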
    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.FLOAT):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    i = 0
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2**31)).all() and (
                            result_arr <= ((2**31) - 1)
                        ).all():
                            break

                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

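    # Numeric example for tvgMul (informative comment): with int32 operands
    # and shift = 4, each result is computed as (a * b + 8) >> 4 in 64 bits;
    # if any element falls outside the int32 range both operands are halved
    # and the check is repeated until every product fits.
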
    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

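    # Numeric example for tvgReduceSum (informative comment): for a shape of
    # [4, 50] the largest axis is 50, so range_val = int(2**31 / 50)
    # = 42_949_672, and even a full reduction of 50 maximal elements stays
    # within the int32 range.

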
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # padding must not exceed the kernel size ?
                        # and p[0] < k[0] and p[1] < k[0]
                        # and p[2] < k[1] and p[3] < k[1]
                        # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
                        # the padded shape must exceed the kernel size
                        and (ifm_shape[1] + p[0] + p[1]) > k[0]
                        and (ifm_shape[2] + p[2] + p[3]) > k[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
                        # the padded shape must exceed the dilation
                        and (ifm_shape[1] + p[0] + p[1]) > d[0]
                        and (ifm_shape[2] + p[2] + p[3]) > d[1]
                        and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
                    ):
                        remainders = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            remainders.append(
                                (
                                    ifm_shape[index + 1]
                                    - 1
                                    + p[pad_offset]
                                    + p[pad_offset + 1]
                                    - (k[index] - 1) * d[index]
                                )
                                % s[index]
                            )
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in p]),
                                        "".join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list

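    # Sparsity example for agConv (informative comment): with 16 paddings,
    # 4 strides and 4 dilations there are 256 combinations; sparsity_factor
    # 120 gives sparsity = 256 // 120 + 1 = 3, which the loop then bumps to
    # 7 (the first value not a multiple of 2, 3 or 5), so roughly 1 in 7
    # combinations is kept.
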
    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        if isinstance(dtypes, list) or isinstance(dtypes, tuple):
            input_dtype = dtypes[0]
        else:
            input_dtype = dtypes

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FLOAT]
        elif dtype == DType.FLOAT:
            accum_dtypes = [DType.FLOAT]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(filter_shape[1], filter_shape[2])
            p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = [
                x
                for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            ]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {
                        x
                        for x in itertools.product(
                            *([[smallest_padding_size, bigPadding]] * 4)
                        )
                    }
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = len(paddings) * len(strides) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list

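    # Output-shape example for agTransposeConv2D (informative comment): an
    # 8x8 spatial input with stride 2, zero padding and a 3x3 kernel gives
    # oh = (8 - 1) * 2 + 0 + 0 + 3 = 17, so the output is 17x17.
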
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.FLOAT):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        for paddings in shape_pad_values:
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                if all([p > -1 for p in paddings[i]]):
                    args_valid = False

            if args_valid:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list

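    # Name example for agPad (informative comment): a rank 2 shape with
    # paddings ((0, 1), (1, 0)) produces the argument name "pad0110" - the
    # before/after pairs concatenated per dimension.
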
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        # Stride must be greater than 1 to force non-integer error
        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FLOAT]
        elif dtype == DType.FLOAT:
            accum_dtypes = [DType.FLOAT]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if testGen.args.oversize:
            # add some oversize argument values
            bigStride = 7
            strides.update(
                {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
            )
            bigKernel = 9
            kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
            if max(shape) < 64:
                # padding must be less than the kernel size
                bigPadding = bigKernel - 1
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 4))}
                )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values
            arg_str_elems = [
                "".join([str(x) for x in stride]),
                "".join([str(x) for x in kern]),
                "".join([str(x) for x in pad]),
            ]
            # Note: different order to string
            arg_val_elems = [stride, pad, kern]

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                arg_val_elems.insert(0, accum)
            return (arg_str.format(*arg_str_elems), arg_val_elems)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_vals = [a, sNew, pNew, kNew]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
                            remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                arg_vals = [a, s, p, k]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        n += 1

        return arg_list

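    # Remainder example for agPooling (informative comment): for H = 16 with
    # pads 1 + 1, kernel 3 and stride 2, (16 + 2 - 3) % 2 == 1, so the
    # output height is not integer exact and the combination is only kept
    # when generating PoolingOutputShapeNonInteger error tests.
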
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[outDtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

1536 @staticmethod
1537 def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
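        """Generate the two rounding modes for ARITHMETIC_RIGHT_SHIFT."""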
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets the factors of a number up to its square root.
    @staticmethod
    def getFactors(val, start=1):
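        # e.g. getFactors(24) -> [1, 2, 3, 4]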
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
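        """Generate new shapes with the same total element count.

        Each new shape is built from a random factorization of the
        original element count at a random rank (1-6), skipping
        duplicate shapes.
        """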
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it runs for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank - 1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
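        """Generate random permutations of the input dimensions.

        For ERROR_IF tests the permutations instead use out-of-bounds
        or duplicated indices.
        """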
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Randomly shuffle the full list of permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create a list with the required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
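        """Generate random (start, size) slice arguments within the input shape."""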
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If an ERROR_IF test is required, incorrect start/size values will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
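        """Generate random multiples for each dimension.

        Multiples are kept small (at most 3) so the tiled tensor does
        not become enormous.
        """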
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random but small multiples, as large ones tend
            # to generate enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Use a multiple of 1 if the ifm_shape dimension is large,
                    # to keep the tensor size down
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
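        """Generate RESIZE scale, offset and border permutations.

        Parameters come from one of three strategies chosen at random:
        fully random values, integer upscale/downscale, or common
        aspect-ratio (letterbox/pillarbox) resizes.
        """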
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
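            # e.g. aspect_ratio=(16, 9) with invert=False gives scale=(9, 1, 16, 1),
            # plus either a random letterbox border_y or a pillarbox border_x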
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0, 0).  Otherwise (-0.5, -0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return the list of valid scale_*_d values (up to 4) for a given input dimension
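                    # (ifm_dim % x == 1 means (ifm_dim - 1) is divisible by x, which
                    # keeps the output dimension an integer when the numerator is 1)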
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
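                    # OH = ((IH - 1) * scale_y_n - offset_y + border_y) / scale_y_d + 1
                    # OW = ((IW - 1) * scale_x_n - offset_x + border_x) / scale_x_d + 1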
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            perm += 1
                            continue
                    else:
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create an invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
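        """Generate the random TABLE lookup values.

        INT8 uses a 256-entry table; INT16 uses a 513-entry table whose
        entry-to-entry slopes are clamped to fit in a signed 16-bit value.
        """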
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iterations in [0, 1, 4]:
            arg_list.append(("iter{}".format(iterations), [iterations]))

        return arg_list