# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from tosa.DType import DType

DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1},
    DType.INT4: {"str": "i4", "width": 4},
    DType.INT8: {"str": "i8", "width": 8},
    DType.UINT8: {"str": "u8", "width": 8},
    DType.INT16: {"str": "i16", "width": 16},
    DType.UINT16: {"str": "u16", "width": 16},
    DType.INT32: {"str": "i32", "width": 32},
    DType.INT48: {"str": "i48", "width": 48},
    DType.FP16: {"str": "f16", "width": 16},
    DType.BF16: {"str": "bf16", "width": 16},
    DType.FP32: {"str": "f32", "width": 32},
}


class ArgGen:
    """Argument generator functions. These functions take a shape and dtype to
    create arguments for an operator. Methods are prefixed with 'ag' to make
    searching easy."""

    def __init__(self):
        pass

    @staticmethod
    def typeStr(dtype):
        if dtype in DTYPE_ATTRIBUTES:
            return DTYPE_ATTRIBUTES[dtype]["str"]
        else:
            raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))
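
    # Illustrative usage, as a sketch (assumes the tosa Python bindings
    # imported above):
    #   ArgGen.typeStr(DType.FP32)  -> "f32"
    #   ArgGen.typeStr(DType.INT48) -> "i48"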

    @staticmethod
    def agNone(op, shapes, rng):
        """A trivial argument generator for operators that only take tensor
        operands"""
        return [("", [])]

    # Build the axis argument for operators where we want to iterate over N axes
    # as an argument
    @staticmethod
    def agAxes(op, shapes, rng):
        axes = []
        if shapes == ():
            axes.append(["_axis_0", [0]])
            return axes

        for i in range(-len(shapes), len(shapes), 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes
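
    # Sketch of the generated arguments, assuming shapes == (2, 3):
    #   ArgGen.agAxes(op, (2, 3), rng)
    #   -> [["_axis_m2", [-2]], ["_axis_m1", [-1]], ["_axis_0", [0]], ["_axis_1", [1]]]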

    # Build the axis LIST argument for operators that take an axis list.
    # This builds a list of each axis individually, plus one element
    # that contains a list of all axes. Note that we need to pack the list in
    # an additional list so that it isn't exploded when being passed to the
    # build_operator function.
    # tensor_arg_count not used
    @staticmethod
    def agAxesList(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc, [a]])

        axes_list.append(["_axisall", [list(range(len(shapes)))]])
        axes_list.append(["_axisall_none", [None]])
        return axes_list
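
    # Packing sketch, assuming shapes == (2, 3): each agAxes value gains one
    # more level of list, e.g. ["_axis_0", [[0]]], and the all-axes entry
    # becomes ["_axisall", [[0, 1]]], so build_operator receives the list as
    # a single argument.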

    @staticmethod
    def agAxesListKeepdims(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc + "_keep0", [a, False]])
            # avoid trying to reduce an axis of shape 1, as the TFL converter
            # will optimize away the entire reduction
            if (a[0] >= 0 and shapes[a[0]] != 1) or (
                a[0] < 0 and shapes[len(shapes) + a[0]] != 1
            ):
                axes_list.append([desc + "_keep1", [a, True]])

        axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]])
        axes_list.append(["_axisall_keep0_none", [None, False]])
        # another instance where the reduce gets optimized out.
        if len(shapes) != 1:
            axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]])
            axes_list.append(["_axisall_keep1_none", [None, True]])
        # no longer test axis empty, as TFL converter optimizes the reduce out
        return axes_list

    # conv2d argument generators build the TF constants
    @staticmethod
    def agConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    for dilation_h in [1, 2]:
                        for dilation_w in [1, 2]:

                            # Disqualify argument combinations that would cause
                            # an illegal convolution

                            if (padding == "VALID") and (
                                (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                                or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                            ):
                                continue

                            if (
                                (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                != 0
                            ) or (
                                (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_dilat{}{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        dilation_h,
                                        dilation_w,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        padding,
                                        [dilation_h, dilation_w],
                                    ],
                                ]
                            )
        return arg_list
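
    # Worked instance of the "exact integer output" check above, assuming
    # shapes == (1, 31, 31, 8), a 2x2 filter and dilation 1:
    #   (31 - 1 - (2 - 1) * 1) % 2 == 1, so stride 2 is skipped for that
    #   dimension, while stride 1 leaves remainder 0 and is kept.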

    # conv3d argument generators build the TF constants
    @staticmethod
    def agConv3d(op, shapes, rng):
        arg_list = []

        # input shape = [OC, KD, KH, KW, IC]
        # Must be rank 5
        if len(shapes) != 5:
            return arg_list

        if len(op["filter"]) < 3:
            return arg_list

        filter_d, filter_h, filter_w = op["filter"]

        # strides, padding, dilations
        for stride_d in [1, 2]:
            for stride_h in [1, 2]:
                for stride_w in [1, 2]:
                    for padding in ["SAME", "VALID"]:
                        for dilation_d in [1, 2]:
                            for dilation_h in [1, 2]:
                                for dilation_w in [1, 2]:

                                    # Disqualify argument combinations that would cause
                                    # an illegal convolution
                                    # fmt: off
                                    if (padding == "VALID") and (
                                        (shapes[1] - (filter_d - 1) * 2 - dilation_d) <= 0
                                        or (shapes[2] - (filter_h - 1) * 2 - dilation_h) <= 0
                                        or (shapes[3] - (filter_w - 1) * 2 - dilation_w) <= 0
                                    ):
                                        continue

                                    if (
                                        (shapes[1] - 1 - (filter_d - 1) * dilation_d) % stride_d
                                        != 0
                                    ) or (
                                        (shapes[2] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                        != 0
                                    ) or (
                                        (shapes[3] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                        != 0
                                    ):
                                        # Not an exact integer output
                                        continue
                                    # fmt: on

                                    # TODO investigate the error `CPU implementation of Conv3D
                                    # currently only supports dilated rates of 1.` from TensorFlow.
                                    # Only test dilations = [1, 1, 1, 1, 1] for now.
                                    if (
                                        (dilation_d != 1)
                                        or (dilation_h != 1)
                                        or (dilation_w != 1)
                                    ):
                                        continue

                                    # TensorFlow expects strides to be a list of ints with length >= 5.
                                    # Strides and dilations in the batch and depth dimensions must be 1.
                                    arg_list.append(
                                        [
                                            "_st{}{}{}{}{}_pad{}_dilat{}{}{}{}{}".format(
                                                1,
                                                stride_d,
                                                stride_h,
                                                stride_w,
                                                1,
                                                padding,
                                                1,
                                                dilation_d,
                                                dilation_h,
                                                dilation_w,
                                                1,
                                            ),
                                            [
                                                [1, stride_d, stride_h, stride_w, 1],
                                                padding,
                                                [
                                                    1,
                                                    dilation_d,
                                                    dilation_h,
                                                    dilation_w,
                                                    1,
                                                ],
                                            ],
                                        ]
                                    )
        return arg_list

    # depthwise conv2d argument generators build the TF constants
    @staticmethod
    def agDepthwiseConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations. Depthwise conv2d is the same as conv2d
        # except that strides in h/w must be the same and the argument must be
        # formatted as [1, stride_h, stride_w, 1] in TF.
        for stride in [1, 2]:
            for padding in ["SAME", "VALID"]:
                for dilation_h in [1, 2]:
                    for dilation_w in [1, 2]:

                        # Disqualify argument combinations that would cause an illegal
                        # convolution

                        if (padding == "VALID") and (
                            (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                            or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                        ):
                            continue

                        # When dilation is used, stride must be 1x1 (TF rules)
                        if dilation_h > 1 or dilation_w > 1:
                            if stride > 1:
                                continue

                        # Dilation must evenly divide the tensor. Some of our inputs
                        # intentionally use odd-sized tensors.
                        if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
                            continue

                        if (
                            (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
                        ) or (
                            (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
                        ):
                            # Not an exact integer output
                            continue

                        arg_list.append(
                            [
                                "_st{}{}_pad{}_dilat{}{}".format(
                                    stride, stride, padding, dilation_h, dilation_w
                                ),
                                [
                                    [1, stride, stride, 1],
                                    padding,
                                    [dilation_h, dilation_w],
                                ],
                            ]
                        )
        return arg_list

    # transpose conv2d argument generators build the TF constants
    @staticmethod
    def agTransposeConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    if padding == "SAME":
                        out_height = (shapes[1]) * stride_h
                        out_width = (shapes[2]) * stride_w
                    else:  # padding == 'VALID'
                        out_height = (shapes[1] - 1) * stride_h + filter_h
                        out_width = (shapes[2] - 1) * stride_w + filter_w

                    output_shape = [shapes[0], out_height, out_width, shapes[3] * 2]
                    arg_list.append(
                        [
                            "_st{}{}_pad{}".format(stride_h, stride_w, padding),
                            [output_shape, [stride_h, stride_w], padding],
                        ]
                    )
        return arg_list
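
    # Output-size reasoning used above, assuming shapes == (1, 8, 8, 4),
    # a 3x3 filter and stride 2:
    #   SAME:  out_height = 8 * 2 = 16
    #   VALID: out_height = (8 - 1) * 2 + 3 = 17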

    @staticmethod
    def agPooling(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for kernel_h in [1, 2]:
                    for kernel_w in [1, 2]:
                        for padding in ["SAME", "VALID"]:

                            if (padding == "VALID") and (
                                (shapes[1] % (kernel_h * stride_h) > 0)
                                or (shapes[2] % (kernel_w * stride_w) > 0)
                                or (shapes[1] <= kernel_h)
                                or (shapes[2] <= kernel_w)
                            ):
                                continue

                            if (padding == "SAME") and (
                                (shapes[1] < kernel_h) or (shapes[2] < kernel_w)
                            ):
                                continue

                            if ((shapes[1] - kernel_h) % stride_h != 0) or (
                                (shapes[2] - kernel_w) % stride_w != 0
                            ):
                                # Not an exact integer output
                                continue

                            # Note: the tf.nn.avg_pool2d API doesn't support setting an
                            # accumulator type; a dummy value is added to the test name
                            # as a reminder
                            accum_dtype = ArgGen.typeStr(DType.INT32)
                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_kern{}{}_acc{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        kernel_h,
                                        kernel_w,
                                        accum_dtype,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        [kernel_h, kernel_w],
                                        padding,
                                    ],
                                ]
                            )
        return arg_list

    @staticmethod
    def getFactors(val, start=1):
        factors = []
        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors
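
    # Behavior sketch: only factors strictly below sqrt(val) are returned,
    # e.g. ArgGen.getFactors(24) -> [1, 2, 3] (4, 6, 8, 12 and 24 are
    # omitted) and ArgGen.getFactors(16) -> [1, 2].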

    @staticmethod
    def agReshape(op, shapes, rng):
        # This is slow code. Fortunately, the numbers involved are small
        arg_list = []

        total_elements = 1
        for s in shapes:
            total_elements *= s

        # Find integer factors of this shape
        factors = ArgGen.getFactors(total_elements)

        for rank in range(1, len(shapes) + 1):
            if len(factors) < rank:
                break

            new_shape = []
            remaining_elements = total_elements

            # Randomly shuffle the factors and iteratively pick from the factors
            # of the remaining elements
            shuffled_factors = rng.permutation(factors)
            for i in range(rank):
                # Pick 'rank' factors; the remainder is appended below as the
                # final dimension
                new_shape.append(shuffled_factors[0])
                remaining_elements = remaining_elements // shuffled_factors[0]
                shuffled_factors = rng.permutation(
                    ArgGen.getFactors(remaining_elements)
                )
            new_shape.append(remaining_elements)

            # Don't do no-op reshapes because TFLite optimizes out the op
            if new_shape == list(shapes):
                continue

            arg_list.append(["_rank{}".format(rank), [new_shape]])

        return arg_list
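
    # One possible (random) outcome, assuming shapes == (2, 3, 4), i.e. 24
    # elements: ["_rank1", [[3, 8]]] and ["_rank2", [[2, 2, 6]]]. Each
    # generated shape has rank + 1 dimensions because the remainder is
    # appended as the final dimension.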

    @staticmethod
    def agTranspose(op, shapes, rng):
        arg_list = []

        # Must have at least two dimensions to transpose
        if len(shapes) < 2:
            return arg_list

        # Pick a bunch of random permutations
        range_arr = np.arange(len(shapes))
        for i in range(len(shapes)):
            perm = rng.permutation(range_arr).astype(np.int32)
            # print('\n shape {} permute{} perm: {} arr: {}'.format(shapes, i,
            # perm, range_arr))
            if np.allclose(perm, range_arr):
                print("skipped")
                continue
            arg_list.append(["_permute{}".format(i), [perm]])

        return arg_list

    @staticmethod
    def agSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random start points, axes, and strides
            start = np.empty((rank), dtype=int)
            size = np.empty((rank), dtype=int)
            for j in range(rank):
                if shapes[j] > 2:
                    start[j] = rng.integers(0, shapes[j] - 2)
                    # print('j = {}: {} - {} - 1: {}'.format(j, shapes[j],
                    # start[j], shapes[j] - start[j] - 1))
                    size[j] = rng.integers(1, shapes[j] - start[j] - 1)
                else:
                    start[j] = 0
                    size[j] = shapes[j]

            arg_list.append(["_perm{}".format(i), [start, size]])

        return arg_list

    @staticmethod
    def agStridedSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Reference model is limited to rank=6 internally right now
        if rank > 3:
            return arg_list

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random begin points, axes, and strides
            begin = np.empty((rank), dtype=int)
            end = np.empty((rank), dtype=int)
            strides = np.empty((rank), dtype=int)

            begin_mask = rng.integers(0, (1 << (rank - 1)))
            end_mask = rng.integers(0, (1 << (rank - 1)))

            for j in range(rank):

                if begin_mask & (1 << j) or shapes[j] < 2:
                    begin[j] = 0
                else:
                    begin[j] = rng.integers(0, shapes[j] - 1)

                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
                    end[j] = shapes[j]
                else:
                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)

                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)

                if not possible_stride:
                    strides[j] = 1
                else:
                    strides[j] = rng.choice(possible_stride)

            # Randomly set the masks, except ellipsis_mask and new_axis_mask,
            # which must be zero for now. For the begin/end masks to work here,
            # strides would have to be adjusted to still be divisible...
            ellipsis_mask = 0
            new_axis_mask = 0

            # if rng.choice([0, 1]) and rank > 1:
            #     new_axis_mask = 1 << rng.integers(0, rank - 1)
            # else:
            #     new_axis_mask = 0

            if rng.choice([0, 1]) and rank > 1:
                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
            else:
                shrink_axis_mask = 0

            # Only one of these bits may be set. Prefer shrink_axis_mask
            new_axis_mask = new_axis_mask & ~shrink_axis_mask

            arg_list.append(
                [
                    "_perm{}".format(i),
                    [
                        begin,
                        end,
                        strides,
                        begin_mask,
                        end_mask,
                        ellipsis_mask,
                        new_axis_mask,
                        shrink_axis_mask,
                    ],
                ]
            )

            # print('Shape: {} begin={} end={} strides={} begin_mask={:x}
            # end_mask={:x} new_axis_mask={:x} shrink_mask={:x}'.format(shapes,
            # begin, end, strides, begin_mask, end_mask, new_axis_mask,
            # shrink_axis_mask))

        return arg_list
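
    # Mask semantics sketch (per tf.strided_slice): if bit j of begin_mask is
    # set, begin[j] is ignored and the fullest possible range is used for
    # dimension j; end_mask works the same way for end[j]. With rank 2,
    # begin_mask == 0b01 ignores begin[0] only.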

    # tf.stack axis can be in [-rank(input) - 1, rank(input)]
    @staticmethod
    def agStack(op, shapes, rng):
        axes = []
        for i in range(-len(shapes) - 1, len(shapes) + 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes
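
    # Sketch, assuming shapes == (2, 3): the axes generated are -3, -2, -1,
    # 0, 1 and 2, matching tf.stack's accepted range for a rank-2 input.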

    @staticmethod
    def agMirrorPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for mode in ["REFLECT", "SYMMETRIC"]:
            for left in range(3):
                for right in range(3):
                    paddings = np.zeros((rank, 2), dtype=np.int32)
                    is_valid = True

                    # Fill in the padding parameters if the values are valid on each
                    # dimension, otherwise drop that case.
                    for d in range(rank):
                        paddings[d, 0] = left
                        paddings[d, 1] = right

                        # In "REFLECT" mode, paddings must be no greater than tensor dim size - 1.
                        if mode == "REFLECT":
                            if (left > shapes[d] - 1) or (right > shapes[d] - 1):
                                is_valid = False
                                break

                        # In "SYMMETRIC" mode, paddings must be no greater than tensor dim size.
                        else:
                            if (left > shapes[d]) or (right > shapes[d]):
                                is_valid = False
                                break

                    if is_valid:
                        arg_list.append(
                            [
                                "_pad{}{}_{}".format(left, right, mode[0:3].lower()),
                                [paddings, mode],
                            ]
                        )
        return arg_list

    @staticmethod
    def agPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for left in range(3):
            for right in range(3):
                # Padding nothing in TensorFlow Lite causes the interpreter to
                # fail to set the input tensor properly due to a data type
                # mismatch.
                if (left == 0) and (right == 0):
                    continue

                # A simple way to generate an explicit pad_const, including zero.
                pad_const = (left - right) * rng.integers(0, 5, dtype=np.int32)
                padding = np.zeros((rank, 2), dtype=np.int32)
                for d in range(rank):
                    padding[d, 0] = left
                    padding[d, 1] = right

                arg_list.append(
                    ["_pad{}{}".format(left, right), [padding, pad_const]]
                )
        return arg_list
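
    # The pad_const expression above spans negative, zero and positive
    # constants: (left - right) ranges over -2..2 across the loops, and the
    # random factor is drawn from [0, 5).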

    @staticmethod
    def agResize(op, shapes, rng):
        args = []
        for mode in ["nearest", "bilinear"]:
            for align_corners in [True, False]:
                for half_pixel in [True, False]:
                    # If half_pixel_centers is True, align_corners must be False.
                    if (align_corners is True) and (half_pixel is True):
                        continue

                    for i in range(1, 4):
                        args.append(
                            [
                                "_{}_align{}_half{}_scale{}".format(
                                    mode, int(align_corners), int(half_pixel), i
                                ),
                                [mode, align_corners, half_pixel, i],
                            ]
                        )
        return args

    @staticmethod
    def agFill(op, shapes, rng):
        values = []
        for i in range(4):
            value = rng.integers(0, 10, dtype=np.int32)
            values.append(["_value{}".format(value), [shapes, value]])
        return values

    @staticmethod
    def getValuesToSum(total, rng):
        # Get a list of random integers that sum up to 'total'
        vals = []

        # rng.integers() requires min and max to differ, so once the
        # remainder gets down to 1, stop and append the final 1 below
        while total > 1:
            vals.append(rng.integers(1, total))
            total = total - vals[-1]

        if total == 1:
            vals.append(1)

        return vals
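
    # One possible (random) outcome: getValuesToSum(10, rng) might produce
    # [3, 5, 1, 1] -- each draw is taken from [1, remaining), and the
    # trailing 1 is appended once the remainder reaches 1.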

    @staticmethod
    def agSplit(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Shuffle the random number generator a few more times to get
        # a better range of axes across shapes
        for i in range(rank):
            for j in range(shapes[i]):
                rng.integers(shapes[i])

        for i in range(3):
            # Need to generate tests for both the num_splits and size_vector versions.
            axis = rng.choice(np.arange(0, rank))

            # For num_splits, get a few divisors of the given axis
            divs = ArgGen.getFactors(shapes[axis], 2)

            if divs:
                # Get no more than 2 samples
                splits = list(rng.choice(divs, size=2))

                for s in splits:
                    arg_list.append(
                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
                    )

            # For vector splits, get a list of integers that sum up to the axis size
            vals = ArgGen.getValuesToSum(shapes[axis], rng)

            if len(vals) > 1:
                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])

        return arg_list

    @staticmethod
    def agTile(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # create 1D multiples list
        multiples = list()
        for i in range(rank):
            multiples.append(rng.integers(1, 4))

        multiples_str = "x".join(list(str(i) for i in multiples))

        arg_list.append(["_tile_{}".format(multiples_str), [multiples]])

        return arg_list

    @staticmethod
    def agGather(op, shapes, rng):
        args = []
        for batch_dims in range(len(shapes) - 1):
            for axis in range(batch_dims, len(shapes)):
                # indices value must be within [0, shapes[i])

                # Create an arbitrary shape for the indices
                indices_rank = rng.integers(batch_dims + 1, 4)
                indices_shape = rng.integers(1, 8, size=indices_rank)

                # Copy in the batch dimensions because they must match
                for b in range(batch_dims):
                    indices_shape[b] = shapes[b]

                # Calculate total element count
                indices_size = 1
                for j in range(indices_rank):
                    indices_size = indices_shape[j] * indices_size

                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
                    indices_shape
                )

                args.append(
                    [
                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
                        [indices, batch_dims, axis],
                    ]
                )
        return args

    @staticmethod
    def agGatherND(op, shapes, rng):
        args = []

        for N in range(1, len(shapes) - 1):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            indices_count = 1
            for i in range(indices_rank - 1):
                indices_count = indices_count * indices_shape[i]

            indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32)

            for i in range(indices_count):
                for j in range(N):
                    indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0]

            indices = indices_list.reshape(indices_shape)

            args.append(["_n{}".format(N), [indices]])

        return args

    @staticmethod
    def agScatterND(op, shapes, rng):
        args = []

        # ScatterND has to generate a constant shapes tensor, indices
        # tensor, and a tensor of updates. Unfortunately, the updates
        # need to be a size that's based on the N generated in this
        # function and the dtype known only in the TensorGen function,
        # but not in ArgGen.
        #
        # There are many bad ways to solve this and we'll choose the
        # least of the evils which still gives reasonable coverage of
        # the possible operand shapes.
        for N in range(1, len(shapes)):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            # Store the shapes, and the indices value tensor as arguments.
            args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]])

        return args

    @staticmethod
    def agSpaceToBatch(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []

        for i in range(block_rank):
            block_size = 2
            padding_size = block_size - (shapes[i + 1] % block_size)
            block_shape.append(block_size)
            padding_shape.append([0, padding_size])

        args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]])
        return args
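
    # Padding sketch for block_size == 2: an odd spatial dim of 5 gets
    # padding_size = 2 - (5 % 2) = 1 (padded to 6), while an even dim of 4
    # gets 2 - (4 % 2) = 2 (padded to 6); either way the padded extent is
    # divisible by the block size.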

    @staticmethod
    def agBatchToSpace(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []
        block_prod = 1

        for i in range(block_rank):
            block_size = 2
            block_prod = block_prod * block_size
            crop_size = 0
            block_shape.append(block_size)
            padding_shape.append([0, crop_size])

        # batch / prod(block_shape[i]) must be an integer
        # transpose to swap depth and batch, so shape[-1] would be the batch dim
        if shapes[-1] % block_prod == 0:
            args.append(
                ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]
            )

        return args

    @staticmethod
    def agSpaceToDepth(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2

        # spatial dimension must be divisible by block_size
        if shapes[1] % block_size != 0 or shapes[2] % block_size != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    @staticmethod
    def agDepthToSpace(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2
        # depth dimension must be divisible by block_size * block_size
        if shapes[3] % (block_size * block_size) != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    @staticmethod
    def agFakequant(op, shapes, rng):
        args = []
        for num_bits in [8, 16]:
            for narrow in [False, True]:
                args.append(
                    ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]]
                )

        return args

    @staticmethod
    def agShift(op, shapes, rng):
        args = []

        for shift in rng.integers(0, 32, size=8):
            args.append(["_shift{}".format(shift), [shift]])

        return args

    @staticmethod
    def agFloat(op, shapes, rng):
        args = []

        # enumerate so each generated test gets a distinct name suffix; the
        # counter was previously never incremented
        for i, alpha in enumerate(np.float32(rng.random(size=2))):
            args.append(["_{}".format(i), [alpha]])

        return args

    # Similar to agAxes, but tf.one_hot only allows axis in [-1, rank(input)]
    @staticmethod
    def agOneHot(op, shapes, rng):
        axes = []
        for i in range(-1, len(shapes) + 1, 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

    @staticmethod
    def agRFFT2d(op, shape, rng):
        args = []

        # Must be rank 3 input tensor
        if len(shape) != 3:
            return []

        # Check rfft2d with enforced fft_length
        for fft_length_h in [2, 32]:
            for fft_length_w in [2, 8, 16]:
                fft_length = [fft_length_h, fft_length_w]
                args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])

        # Check rfft2d with no fft_length provided (fft_length=None).
        # In this case, the height and width of the input should be
        # used for the calculation. Therefore, we need to check that
        # the input shape is already a power of two.
        def is_power_of_two(x):
            return math.log(x, 2).is_integer()

        height, width = shape[1:3]
        if is_power_of_two(height) and is_power_of_two(width):
            args.append(["_fft_length_None", [None]])

        return args
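
    # Sketch, assuming shape == (1, 8, 16): the loops above generate six
    # "_fft_length_HxW" cases, and because 8 and 16 are both powers of two
    # the "_fft_length_None" case is appended as well.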