# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np


class ArgGen:
    """Argument generator functions. These functions take a shape and dtype to
    create arguments for an operator. Methods are prefixed with 'ag' to make
    search easy."""

    def __init__(self):
        pass

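    # Each generator below returns a list of [name_suffix, arg_list] pairs:
    # name_suffix is appended to the generated test name, and arg_list holds
    # the extra operator arguments passed through to build_operator.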
    @staticmethod
    def agNone(op, shapes, rng):
        """A trivial argument generator for operators that only take tensor
        operands"""
        return [("", [])]

    # Build the axis argument for operators where we want to iterate over N axes
    # as an argument
    @staticmethod
    def agAxes(op, shapes, rng):
        axes = []
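        # e.g. a rank-2 shape yields _axis_m2, _axis_m1, _axis_0 and _axis_1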
        for i in range(-len(shapes), len(shapes), 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

    # Build the axis LIST argument for operators that take an axis list.
    # This builds a list of each axis individually, plus one element
    # that contains a list of all axes. Note that we need to pack the list in
    # an additional list so that it isn't exploded when being passed to the
    # build_operator function.
    # tensor_arg_count not used
    def agAxesList(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc, [a]])

        axes_list.append(["_axisall", [list(range(len(shapes)))]])
        axes_list.append(["_axisall_none", [None]])
        return axes_list

    def agAxesListKeepdims(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc + "_keep0", [a, False]])
            # avoid trying to reduce an axis of shape 1, as the TFL converter
            # will optimize away the entire reduction
            if (a[0] >= 0 and shapes[a[0]] != 1) or (
                a[0] < 0 and shapes[len(shapes) + a[0]] != 1
            ):
                axes_list.append([desc + "_keep1", [a, True]])

        axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]])
        axes_list.append(["_axisall_keep0_none", [None, False]])
        # another instance where the reduction gets optimized out
        if len(shapes) != 1:
            axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]])
            axes_list.append(["_axisall_keep1_none", [None, True]])
        # we no longer test an empty axis list, as the TFL converter optimizes
        # the reduction out
        return axes_list

    # conv2d argument generators build the TF constants
    def agConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations,
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    for dilation_h in [1, 2]:
                        for dilation_w in [1, 2]:

                            # Disqualify argument combinations that would cause
                            # an illegal convolution

                            if (padding == "VALID") and (
                                (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                                or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                            ):
                                continue

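                            # TF computes the VALID output size as
                            # floor((in - (filter - 1) * dilation - 1) / stride) + 1;
                            # the modulo checks below only keep cases where that
                            # division is exact.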
                            if (
                                (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                != 0
                            ) or (
                                (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_dilat{}{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        dilation_h,
                                        dilation_w,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        padding,
                                        [dilation_h, dilation_w],
                                    ],
                                ]
                            )
        return arg_list

    # conv3d argument generators build the TF constants
    def agConv3d(op, shapes, rng):
        arg_list = []

        # input shape = [N, D, H, W, C]
        # Must be rank 5
        if len(shapes) != 5:
            return arg_list

        if len(op["filter"]) < 3:
            return arg_list

        filter_d, filter_h, filter_w = op["filter"]

        # strides, padding, dilations,
        for stride_d in [1, 2]:
            for stride_h in [1, 2]:
                for stride_w in [1, 2]:
                    for padding in ["SAME", "VALID"]:
                        for dilation_d in [1, 2]:
                            for dilation_h in [1, 2]:
                                for dilation_w in [1, 2]:

                                    # Disqualify argument combinations that would cause
                                    # an illegal convolution
                                    # fmt: off
                                    if (padding == "VALID") and (
                                        (shapes[1] - (filter_d - 1) * 2 - dilation_d) <= 0
                                        or (shapes[2] - (filter_h - 1) * 2 - dilation_h) <= 0
                                        or (shapes[3] - (filter_w - 1) * 2 - dilation_w) <= 0
                                    ):
                                        continue

                                    if (
                                        (shapes[1] - 1 - (filter_d - 1) * dilation_d) % stride_d
                                        != 0
                                    ) or (
                                        (shapes[2] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                        != 0
                                    ) or (
                                        (shapes[3] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                        != 0
                                    ):
                                        # Not an exact integer output
                                        continue
                                    # fmt: on

                                    # TODO investigate the error `CPU implementation of Conv3D
                                    # currently only supports dilated rates of 1.` from TensorFlow.
                                    # Only test dilations = [1, 1, 1, 1, 1] for now.
                                    if (
                                        (dilation_d != 1)
                                        or (dilation_h != 1)
                                        or (dilation_w != 1)
                                    ):
                                        continue

                                    # TensorFlow expects strides to be a list of ints that has
                                    # length >= 5. Strides and dilations in the batch and depth
                                    # dimensions must be 1.
                                    arg_list.append(
                                        [
                                            "_st{}{}{}{}{}_pad{}_dilat{}{}{}{}{}".format(
                                                1,
                                                stride_d,
                                                stride_h,
                                                stride_w,
                                                1,
                                                padding,
                                                1,
                                                dilation_d,
                                                dilation_h,
                                                dilation_w,
                                                1,
                                            ),
                                            [
                                                [1, stride_d, stride_h, stride_w, 1],
                                                padding,
                                                [
                                                    1,
                                                    dilation_d,
                                                    dilation_h,
                                                    dilation_w,
                                                    1,
                                                ],
                                            ],
                                        ]
                                    )
        return arg_list

    # depthwise conv2d argument generators build the TF constants
    def agDepthwiseConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations. Depthwise conv2d is the same as conv2d
        # except that strides in h/w must be the same and the argument must be
        # formatted as [1, stride_h, stride_w, 1] in TF.
        for stride in [1, 2]:
            for padding in ["SAME", "VALID"]:
                for dilation_h in [1, 2]:
                    for dilation_w in [1, 2]:

                        # Disqualify argument combinations that would cause an illegal
                        # convolution

                        if (padding == "VALID") and (
                            (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                            or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                        ):
                            continue

                        # When dilation is used, stride must be 1x1 (TF rules)
                        if dilation_h > 1 or dilation_w > 1:
                            if stride > 1:
                                continue

                        # Dilation must evenly divide the tensor. Some of our inputs
                        # intentionally use odd-sized tensors.
                        if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
                            continue

                        if (
                            (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
                        ) or (
                            (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
                        ):
                            # Not an exact integer output
                            continue

                        arg_list.append(
                            [
                                "_st{}{}_pad{}_dilat{}{}".format(
                                    stride, stride, padding, dilation_h, dilation_w
                                ),
                                [
                                    [1, stride, stride, 1],
                                    padding,
                                    [dilation_h, dilation_w],
                                ],
                            ]
                        )
        return arg_list

    # transpose conv2d argument generators build the TF constants
    def agTransposeConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
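                    # Standard TF conv2d_transpose output sizes: SAME scales
                    # the input by the stride, VALID additionally adds the
                    # filter overlap at the border.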
                    if padding == "SAME":
                        out_height = (shapes[1]) * stride_h
                        out_width = (shapes[2]) * stride_w
                    else:  # padding == 'VALID'
                        out_height = (shapes[1] - 1) * stride_h + filter_h
                        out_width = (shapes[2] - 1) * stride_w + filter_w

                    output_shape = [shapes[0], out_height, out_width, shapes[3] * 2]
                    arg_list.append(
                        [
                            "_st{}{}_pad{}".format(stride_h, stride_w, padding),
                            [output_shape, [stride_h, stride_w], padding],
                        ]
                    )
        return arg_list

    def agPooling(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for kernel_h in [1, 2]:
                    for kernel_w in [1, 2]:
                        for padding in ["SAME", "VALID"]:

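                            # Only generate VALID pooling when the kernel fits
                            # inside the spatial dims and the kernel/stride
                            # combination tiles the input exactly.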
                            if (padding == "VALID") and (
                                (shapes[1] % (kernel_h * stride_h) > 0)
                                or (shapes[2] % (kernel_w * stride_w) > 0)
                                or (shapes[1] <= kernel_h)
                                or (shapes[2] <= kernel_w)
                            ):
                                continue

                            if (padding == "SAME") and (
                                (shapes[1] < kernel_h) or (shapes[2] < kernel_w)
                            ):
                                continue

                            if ((shapes[1] - kernel_h) % stride_h != 0) or (
                                (shapes[2] - kernel_w) % stride_w != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_kern{}{}".format(
                                        stride_h, stride_w, padding, kernel_h, kernel_w
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        [kernel_h, kernel_w],
                                        padding,
                                    ],
                                ]
                            )
        return arg_list

    def getFactors(val, start=1):
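        # Note: only factors strictly below sqrt(val) are returned, so val
        # itself is never included and the list may be empty.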
        factors = []
        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors

    def agReshape(op, shapes, rng):
        # This is slow code. Fortunately, the numbers involved are small
        arg_list = []

        total_elements = 1
        for s in shapes:
            total_elements *= s

        # Find integer factors of this shape
        factors = ArgGen.getFactors(total_elements)

        for rank in range(1, len(shapes) + 1):
            if len(factors) < rank:
                break

            new_shape = []
            remaining_elements = total_elements

            # Randomly shuffle the factors and iteratively pick from the factors
            # of the remaining elements
            shuffled_factors = rng.permutation(factors)
            for i in range(rank):
                # Pick the next factor and recompute the factors of what remains
                new_shape.append(shuffled_factors[0])
                remaining_elements = remaining_elements // shuffled_factors[0]
                shuffled_factors = rng.permutation(
                    ArgGen.getFactors(remaining_elements)
                )
            new_shape.append(remaining_elements)

            # Don't do no-op reshapes because TFLite optimizes out the op
            if new_shape == list(shapes):
                continue

            arg_list.append(["_rank{}".format(rank), [new_shape]])

        return arg_list

    def agTranspose(op, shapes, rng):
        arg_list = []

        # Must have at least two dimensions to transpose
        if len(shapes) < 2:
            return arg_list

        # Pick a bunch of random permutations
        range_arr = np.arange(len(shapes))
        for i in range(len(shapes)):
            perm = rng.permutation(range_arr).astype(np.int32)
            # print('\n shape {} permute{} perm: {} arr: {}'.format(shapes, i,
            #       perm, range_arr))
            if np.allclose(perm, range_arr):
                print("skipped")
                continue
            arg_list.append(["_permute{}".format(i), [perm]])

        return arg_list

    def agSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random start points and sizes
            start = np.empty((rank), dtype=int)
            size = np.empty((rank), dtype=int)
            for j in range(rank):
                if shapes[j] > 2:
                    start[j] = rng.integers(0, shapes[j] - 2)
                    # print('j = {}: {} - {} - 1: {}'.format(j, shapes[j],
                    #       start[j], shapes[j] - start[j] - 1))
                    size[j] = rng.integers(1, shapes[j] - start[j] - 1)
                else:
                    start[j] = 0
                    size[j] = shapes[j]

            arg_list.append(["_perm{}".format(i), [start, size]])

        return arg_list

    def agStridedSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Reference model is limited to rank=6 internally right now;
        # only generate tests up to rank 3
        if rank > 3:
            return arg_list

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random begin points, ends, and strides
            begin = np.empty((rank), dtype=int)
            end = np.empty((rank), dtype=int)
            strides = np.empty((rank), dtype=int)

            begin_mask = rng.integers(0, (1 << (rank - 1)))
            end_mask = rng.integers(0, (1 << (rank - 1)))

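            # A set bit in begin_mask/end_mask tells TF to ignore the supplied
            # begin/end value in that dimension and use the full range instead.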
            for j in range(rank):

                if begin_mask & (1 << j) or shapes[j] < 2:
                    begin[j] = 0
                else:
                    begin[j] = rng.integers(0, shapes[j] - 1)

                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
                    end[j] = shapes[j]
                else:
                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)

                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)

                if not possible_stride:
                    strides[j] = 1
                else:
                    strides[j] = rng.choice(possible_stride)

            # Randomly set the masks, except ellipsis_mask and new_axis_mask,
            # which must be zero for now. For begin/end masking to work,
            # strides must be adjusted to still be divisible...
            ellipsis_mask = 0
            new_axis_mask = 0

            # if rng.choice([0, 1]) and rank > 1:
            #     new_axis_mask = 1 << rng.integers(0, rank - 1)
            # else:
            #     new_axis_mask = 0

            if rng.choice([0, 1]) and rank > 1:
                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
            else:
                shrink_axis_mask = 0

            # Only one of these bits may be set. Prefer shrink_axis_mask
            new_axis_mask = new_axis_mask & ~shrink_axis_mask

            arg_list.append(
                [
                    "_perm{}".format(i),
                    [
                        begin,
                        end,
                        strides,
                        begin_mask,
                        end_mask,
                        ellipsis_mask,
                        new_axis_mask,
                        shrink_axis_mask,
                    ],
                ]
            )

            # print('Shape: {} begin={} end={} strides={} begin_mask={:x}
            #       end_mask={:x} new_axis_mask={:x} shrink_mask={:x}'.format(shapes,
            #       begin, end, strides, begin_mask, end_mask, new_axis_mask,
            #       shrink_axis_mask))

        return arg_list

    # tf.stack axis can be in the range [0, rank(input)]
    def agStack(op, shapes, rng):
        axes = []
        for i in range(len(shapes) + 1):
            axes.append(["_axis{}".format(i), [i]])
        return axes

    def agMirrorPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for mode in ["REFLECT", "SYMMETRIC"]:
            for left in range(3):
                for right in range(3):
                    paddings = np.zeros((rank, 2), dtype=np.int32)
                    is_valid = True

                    # Fill in the padding parameter if the values are valid on each dimension,
                    # otherwise drop that case.
                    for d in range(rank):
                        paddings[d, 0] = left
                        paddings[d, 1] = right

                        # In "REFLECT" mode, paddings must be no greater than tensor dim size - 1.
                        if mode == "REFLECT":
                            if (left > shapes[d] - 1) or (right > shapes[d] - 1):
                                is_valid = False
                                break

                        # In "SYMMETRIC" mode, paddings must be no greater than tensor dim size.
                        else:
                            if (left > shapes[d]) or (right > shapes[d]):
                                is_valid = False
                                break

                    if is_valid:
                        arg_list.append(
                            [
                                "_pad{}{}_{}".format(left, right, mode[0:3].lower()),
                                [paddings, mode],
                            ]
                        )
        return arg_list

    def agPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for left in range(3):
            for right in range(3):
                # Padding nothing in TensorFlow Lite causes the interpreter to
                # fail to set the input tensor properly, due to a data type
                # mismatch.
                if (left == 0) and (right == 0):
                    continue

                # A simple way to generate an explicit pad_const, including zero.
                pad_const = (left - right) * rng.integers(0, 5, dtype=np.int32)
                padding = np.zeros((rank, 2), dtype=np.int32)
                for d in range(rank):
                    padding[d, 0] = left
                    padding[d, 1] = right

                arg_list.append(
                    ["_pad{}{}".format(left, right), [padding, pad_const]]
                )
        return arg_list

    def agResize(op, shapes, rng):
        args = []
        for mode in ["nearest", "bilinear"]:
            for align_corners in [True, False]:
                for half_pixel in [True, False]:
                    # If half_pixel_centers is True, align_corners must be False.
                    if (align_corners is True) and (half_pixel is True):
                        continue

                    for i in range(1, 4):
                        args.append(
                            [
                                "_{}_align{}_half{}_scale{}".format(
                                    mode, int(align_corners), int(half_pixel), i
                                ),
                                [mode, align_corners, half_pixel, i],
                            ]
                        )
        return args

    def agFill(op, shapes, rng):
        values = []
        for i in range(4):
            value = rng.integers(0, 10, dtype=np.int32)
            values.append(["_value{}".format(value), [shapes, value]])
        return values

    def getValuesToSum(total, rng):
        # Get a list of random integers that sum up to 'total'
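        # e.g. total=7 might produce [4, 2, 1]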
        vals = []

        # rng.integers() requires min and max to be different, so once the
        # remainder reaches 1, stop and append it directly
        while total > 1:
            vals.append(rng.integers(1, total))
            total = total - vals[-1]

        if total == 1:
            vals.append(1)

        return vals

    def agSplit(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Advance the random number generator a few more times to get
        # a better range of axes across shapes
        for i in range(rank):
            for j in range(shapes[i]):
                rng.integers(shapes[i])

        for i in range(3):
            # Need to generate tests for both the num_splits and size_vector versions.
            axis = rng.choice(np.arange(0, rank))

            # For num_splits, get a few divisors of the given axis size
            divs = ArgGen.getFactors(shapes[axis], 2)

            if divs:
                # Get no more than 2 samples
                splits = list(rng.choice(divs, size=2))

                for s in splits:
                    arg_list.append(
                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
                    )

            # For vector splits, get a list of integers that sum up to the axis size
            vals = ArgGen.getValuesToSum(shapes[axis], rng)

            if len(vals) > 1:
                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])

        return arg_list

    def agTile(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # create 1D multiples list
        multiples = list()
        for i in range(rank):
            multiples.append(rng.integers(1, 4))

        multiples_str = "x".join(list(str(i) for i in multiples))

        arg_list.append(["_tile_{}".format(multiples_str), [multiples]])

        return arg_list

    def agGather(op, shapes, rng):
        args = []
        for batch_dims in range(len(shapes) - 1):
            for axis in range(batch_dims, len(shapes)):
                # indices values must be within [0, shapes[axis])

                # Create an arbitrary shape for the indices
                indices_rank = rng.integers(batch_dims + 1, 4)
                indices_shape = rng.integers(1, 8, size=indices_rank)

                # Copy in the batch dimensions because they must match
                for b in range(batch_dims):
                    indices_shape[b] = shapes[b]

                # Calculate total element count
                indices_size = 1
                for j in range(indices_rank):
                    indices_size = indices_shape[j] * indices_size

                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
                    indices_shape
                )

                args.append(
                    [
                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
                        [indices, batch_dims, axis],
                    ]
                )
        return args

    def agGatherND(op, shapes, rng):
        args = []

        for N in range(1, len(shapes) - 1):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            indices_count = 1
            for i in range(indices_rank - 1):
                indices_count = indices_count * indices_shape[i]

            indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32)

            for i in range(indices_count):
                for j in range(N):
                    indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0]

            indices = indices_list.reshape(indices_shape)

            args.append(["_n{}".format(N), [indices]])

        return args

    def agScatterND(op, shapes, rng):
        args = []

        # ScatterND has to generate a constant shapes tensor, an indices
        # tensor, and a tensor of updates. Unfortunately, the updates
        # need to be a size that's based on the N generated in this
        # function and the dtype known only in the TensorGen function,
        # but not in ArgGen.
        #
        # There are many bad ways to solve this and we'll choose the
        # least of the evils, which still gives reasonable coverage of
        # the possible operand shapes.
        for N in range(1, len(shapes)):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            # Store the shapes and the indices value tensor as arguments.
            args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]])

        return args

    def agSpaceToBatch(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []

        for i in range(block_rank):
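            # Pad each spatial dim up to a multiple of block_size; when the
            # dim is already divisible this still adds a full block of padding.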
            block_size = 2
            padding_size = block_size - (shapes[i + 1] % block_size)
            block_shape.append(block_size)
            padding_shape.append([0, padding_size])

        args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]])
        return args

    def agBatchToSpace(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []
        block_prod = 1

        for i in range(block_rank):
            block_size = 2
            block_prod = block_prod * block_size
            crop_size = 0
            block_shape.append(block_size)
            padding_shape.append([0, crop_size])

        # batch / prod(block_shape[i]) must be an integer;
        # the test transposes to swap depth and batch, so shape[-1] is the batch dim
        if shapes[-1] % block_prod == 0:
            args.append(
                ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]
            )

        return args

    def agSpaceToDepth(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2

        # spatial dimension must be divisible by block_size
        if shapes[1] % block_size != 0 or shapes[2] % block_size != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    def agDepthToSpace(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2
        # depth dimension must be divisible by block_size * block_size
        if shapes[3] % (block_size * block_size) != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    def agFakequant(op, shapes, rng):
        args = []
        for num_bits in [8, 16]:
            for narrow in [False, True]:
                args.append(
                    ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]]
                )

        return args

    def agShift(op, shapes, rng):
        args = []

        for shift in rng.integers(0, 32, size=8):
            args.append(["_shift{}".format(shift), [shift]])

        return args

    def agFloat(op, shapes, rng):
        args = []

        # enumerate so that each generated alpha gets a unique name suffix
        for i, alpha in enumerate(np.float32(rng.random(size=2))):
            args.append(["_{}".format(i), [alpha]])

        return args

    # Similar to agAxes, but tf.one_hot only allows axis in [-1, rank(input)]
    def agOneHot(op, shapes, rng):
        axes = []
        for i in range(-1, len(shapes) + 1, 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

    def agRFFT2d(op, shape, rng):
        args = []

        # Must be rank 3 input tensor
        if len(shape) != 3:
            return []

        # Check rfft2d with enforced fft_length
        for fft_length_h in [2, 32]:
            for fft_length_w in [2, 8, 16]:
                fft_length = [fft_length_h, fft_length_w]
                args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])

        # Check rfft2d with no fft_length provided (fft_length=None).
        # In this case, the height and width of the input should be
        # used for the calculation. Therefore, we need to check that
        # the input shape is already a power of two.
        def is_power_of_two(x):
            return math.log(x, 2).is_integer()

        height, width = shape[1:3]
        if is_power_of_two(height) and is_power_of_two(width):
            args.append(["_fft_length_None", [None]])

        return args