# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np


class ArgGen:
    """Argument generator functions. These functions take a shape and dtype to
    create arguments for an operator. Methods are prefixed with 'ag' to make
    search easy."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(op, shapes, rng):
        """A trivial argument generator for operators that only take tensor
        operands"""
        return [("", [])]

    # Build the axis argument for operators where we want to iterate over N axes
    # as an argument
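    # For example (illustrative), a rank-2 input yields
    #   [["_axis_m2", [-2]], ["_axis_m1", [-1]], ["_axis_0", [0]], ["_axis_1", [1]]]
    # where each entry is a (test-name suffix, argument list) pair.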
    @staticmethod
    def agAxes(op, shapes, rng):
        axes = []
        for i in range(-len(shapes), len(shapes), 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

    # Build the axis LIST argument for operators that take an axis list.
    # This builds a list of each axis individually, plus one element
    # that contains a list of all axes. Note that we need to pack the list in
    # an additional list so that it isn't exploded when being passed to the
    # build_operator function.
    # tensor_arg_count not used
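    # For example (illustrative), a rank-2 input yields entries such as
    #   ["_axis_0", [[0]]] and ["_axisall", [[0, 1]]]
    # i.e. the single argument is itself a list of axes.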
    def agAxesList(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc, [a]])

        axes_list.append(["_axisall", [list(range(len(shapes)))]])
        axes_list.append(["_axisall_none", [None]])
        return axes_list

    def agAxesListKeepdims(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc + "_keep0", [a, False]])
            # avoid trying to reduce an axis of shape 1, as the TFL converter
            # will optimize away the entire reduction
            if (a[0] >= 0 and shapes[a[0]] != 1) or (
                a[0] < 0 and shapes[len(shapes) + a[0]] != 1
            ):
                axes_list.append([desc + "_keep1", [a, True]])

        axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]])
        axes_list.append(["_axisall_keep0_none", [None, False]])
        # another instance where the reduce gets optimized out.
        if len(shapes) != 1:
            axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]])
            axes_list.append(["_axisall_keep1_none", [None, True]])
        # no longer test axis empty, as TFL converter optimizes the reduce out
        return axes_list

71 # conv2d argument generators build the TF constants
72 def agConv2d(op, shapes, rng):
73 arg_list = []
74
75 # Must be rank 4
76 if len(shapes) < 4:
77 return arg_list
78
79 filter_h, filter_w = op["filter"]
80
81 # strides, padding, dilations,
82 for stride_h in [1, 2]:
83 for stride_w in [1, 2]:
84 for padding in ["SAME", "VALID"]:
85 for dilation_h in [1, 2]:
86 for dilation_w in [1, 2]:
87
88 # Disqualify argument combinations that would cause
89 # an illegal convolution
90
91 if (padding == "VALID") and (
92 (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
93 or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
94 ):
95 continue
96
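                            # Keep only combinations where the dilated filter
                            # steps exactly across the input, i.e.
                            #   (in - 1 - (filter - 1) * dilation) % stride == 0
                            # so the output extent (quotient + 1) is an exact integer.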
                            if (
                                (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                != 0
                            ) or (
                                (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_dilat{}{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        dilation_h,
                                        dilation_w,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        padding,
                                        [dilation_h, dilation_w],
                                    ],
                                ]
                            )
        return arg_list

    # conv3d argument generators build the TF constants
    def agConv3d(op, shapes, rng):
        arg_list = []

        # input shape = [OC, KD, KH, KW, IC]
        # Must be rank 5
        if len(shapes) != 5:
            return arg_list

        if len(op["filter"]) < 3:
            return arg_list

        filter_d, filter_h, filter_w = op["filter"]

        # strides, padding, dilations,
        for stride_d in [1, 2]:
            for stride_h in [1, 2]:
                for stride_w in [1, 2]:
                    for padding in ["SAME", "VALID"]:
                        for dilation_d in [1, 2]:
                            for dilation_h in [1, 2]:
                                for dilation_w in [1, 2]:

                                    # Disqualify argument combinations that would cause
                                    # an illegal convolution
                                    # fmt: off
                                    if (padding == "VALID") and (
                                        (shapes[1] - (filter_d - 1) * 2 - dilation_d) <= 0
                                        or (shapes[2] - (filter_h - 1) * 2 - dilation_h) <= 0
                                        or (shapes[3] - (filter_w - 1) * 2 - dilation_w) <= 0
                                    ):
                                        continue

                                    if (
                                        (shapes[1] - 1 - (filter_d - 1) * dilation_d) % stride_d
                                        != 0
                                    ) or (
                                        (shapes[2] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                        != 0
                                    ) or (
                                        (shapes[3] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                        != 0
                                    ):
                                        # Not an exact integer output
                                        continue
                                    # fmt: on

                                    # TODO investigate the error of `CPU implementation of Conv3D
                                    # currently only supports dilated rates of 1.` from Tensorflow.
                                    # Only test dilations = [1, 1, 1, 1, 1] for now.
                                    if (
                                        (dilation_d != 1)
                                        or (dilation_h != 1)
                                        or (dilation_w != 1)
                                    ):
                                        continue

                                    # TensorFlow expects strides to be a list of ints of length >= 5.
                                    # Strides and dilations in the batch and depth dimensions must be 1.
                                    arg_list.append(
                                        [
                                            "_st{}{}{}{}{}_pad{}_dilat{}{}{}{}{}".format(
                                                1,
                                                stride_d,
                                                stride_h,
                                                stride_w,
                                                1,
                                                padding,
                                                1,
                                                dilation_d,
                                                dilation_h,
                                                dilation_w,
                                                1,
                                            ),
                                            [
                                                [1, stride_d, stride_h, stride_w, 1],
                                                padding,
                                                [
                                                    1,
                                                    dilation_d,
                                                    dilation_h,
                                                    dilation_w,
                                                    1,
                                                ],
                                            ],
                                        ]
                                    )
        return arg_list

    # depthwise conv2d argument generators build the TF constants
    def agDepthwiseConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations. Depthwise conv2d is the same as conv2d
        # except that strides in h/w must be the same and the argument must be
        # formatted as [1, stride_h, stride_w, 1] in TF.
        for stride in [1, 2]:
            for padding in ["SAME", "VALID"]:
                for dilation_h in [1, 2]:
                    for dilation_w in [1, 2]:

                        # Disqualify argument combinations that would cause an illegal
                        # convolution

                        if (padding == "VALID") and (
                            (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                            or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                        ):
                            continue

                        # When dilation is used, stride must be 1x1 (TF rules)
                        if dilation_h > 1 or dilation_w > 1:
                            if stride > 1:
                                continue

                        # Dilation must evenly divide the tensor. Some of our inputs
                        # intentionally use odd-sized tensors.
                        if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
                            continue

                        if (
                            (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
                        ) or (
                            (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
                        ):
                            # Not an exact integer output
                            continue

                        arg_list.append(
                            [
                                "_st{}{}_pad{}_dilat{}{}".format(
                                    stride, stride, padding, dilation_h, dilation_w
                                ),
                                [
                                    [1, stride, stride, 1],
                                    padding,
                                    [dilation_h, dilation_w],
                                ],
                            ]
                        )
        return arg_list

    # transpose conv2d argument generators build the TF constants
    def agTransposeConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    if padding == "SAME":
                        out_height = shapes[1] * stride_h
                        out_width = shapes[2] * stride_w
                    else:  # padding == 'VALID'
                        out_height = (shapes[1] - 1) * stride_h + filter_h
                        out_width = (shapes[2] - 1) * stride_w + filter_w

                    output_shape = [shapes[0], out_height, out_width, shapes[3] * 2]
                    arg_list.append(
                        [
                            "_st{}{}_pad{}".format(stride_h, stride_w, padding),
                            [output_shape, [stride_h, stride_w], padding],
                        ]
                    )
        return arg_list

    def agPooling(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for kernel_h in [1, 2]:
                    for kernel_w in [1, 2]:
                        for padding in ["SAME", "VALID"]:

                            if (padding == "VALID") and (
                                (shapes[1] % (kernel_h * stride_h) > 0)
                                or (shapes[2] % (kernel_w * stride_w) > 0)
                                or (shapes[1] <= kernel_h)
                                or (shapes[2] <= kernel_w)
                            ):
                                continue

                            if (padding == "SAME") and (
                                (shapes[1] < kernel_h) or (shapes[2] < kernel_w)
                            ):
                                continue

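                            # Pool output extent is (in - kernel) / stride + 1;
                            # keep only combinations where the division is exact.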
                            if ((shapes[1] - kernel_h) % stride_h != 0) or (
                                (shapes[2] - kernel_w) % stride_w != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_kern{}{}".format(
                                        stride_h, stride_w, padding, kernel_h, kernel_w
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        [kernel_h, kernel_w],
                                        padding,
                                    ],
                                ]
                            )
        return arg_list

    def getFactors(val, start=1):
        # Note: only candidates below floor(sqrt(val)) are tested, so the
        # larger cofactors are never returned
        factors = []
        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors

    def agReshape(op, shapes, rng):
        # This is slow code. Fortunately, the numbers involved are small
        arg_list = []

        total_elements = 1
        for s in shapes:
            total_elements *= s

        # Find integer factors of this shape
        factors = ArgGen.getFactors(total_elements)

        for rank in range(1, len(shapes) + 1):
            if len(factors) < rank:
                break

            new_shape = []
            remaining_elements = total_elements

            # Randomly shuffle the factors and iteratively pick from the factors
            # of the remaining elements
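            # (illustrative: with 24 elements at rank 1, picking factor 4
            # leaves 6, emitting the shape [4, 6])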
            shuffled_factors = rng.permutation(factors)
            for i in range(rank):
                # Pick the next factor, then recompute the factors of what remains
                new_shape.append(shuffled_factors[0])
                remaining_elements = remaining_elements // shuffled_factors[0]
                shuffled_factors = rng.permutation(
                    ArgGen.getFactors(remaining_elements)
                )
            new_shape.append(remaining_elements)

            # Don't do no-op reshapes because TFLite optimizes out the op
            if new_shape == list(shapes):
                continue

            arg_list.append(["_rank{}".format(rank), [new_shape]])

        return arg_list

    def agTranspose(op, shapes, rng):
        arg_list = []

        # Must have at least two dimensions to transpose
        if len(shapes) < 2:
            return arg_list

        # Pick a bunch of random permutations
        range_arr = np.arange(len(shapes))
        for i in range(len(shapes)):
            perm = rng.permutation(range_arr).astype(np.int32)
            # Skip identity permutations; a no-op transpose adds no coverage
            if np.allclose(perm, range_arr):
                continue
            arg_list.append(["_permute{}".format(i), [perm]])

        return arg_list

    def agSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random start points and sizes
            start = np.empty((rank), dtype=int)
            size = np.empty((rank), dtype=int)
            for j in range(rank):
                if shapes[j] > 2:
                    start[j] = rng.integers(0, shapes[j] - 2)
                    size[j] = rng.integers(1, shapes[j] - start[j] - 1)
                else:
                    start[j] = 0
                    size[j] = shapes[j]

            arg_list.append(["_perm{}".format(i), [start, size]])

        return arg_list

    def agStridedSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Reference model is limited to rank=6 internally right now
        if rank > 3:
            return arg_list

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random begin points, end points, and strides
            begin = np.empty((rank), dtype=int)
            end = np.empty((rank), dtype=int)
            strides = np.empty((rank), dtype=int)

            begin_mask = rng.integers(0, (1 << (rank - 1)))
            end_mask = rng.integers(0, (1 << (rank - 1)))

            for j in range(rank):

                if begin_mask & (1 << j) or shapes[j] < 2:
                    begin[j] = 0
                else:
                    begin[j] = rng.integers(0, shapes[j] - 1)

                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
                    end[j] = shapes[j]
                else:
                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)

                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)

                if not possible_stride:
                    strides[j] = 1
                else:
                    strides[j] = rng.choice(possible_stride)

            # Randomly set the masks, except ellipsis_mask and new_axis_mask,
            # which must be zero for now. For the begin/end masks to work,
            # strides must be adjusted to still be divisible...
            ellipsis_mask = 0
            new_axis_mask = 0

            # if rng.choice([0, 1]) and rank > 1:
            #     new_axis_mask = 1 << rng.integers(0, rank - 1)
            # else:
            #     new_axis_mask = 0

            if rng.choice([0, 1]) and rank > 1:
                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
            else:
                shrink_axis_mask = 0

            # Only one of these bits may be set. Prefer shrink_axis_mask
            new_axis_mask = new_axis_mask & ~shrink_axis_mask

            arg_list.append(
                [
                    "_perm{}".format(i),
                    [
                        begin,
                        end,
                        strides,
                        begin_mask,
                        end_mask,
                        ellipsis_mask,
                        new_axis_mask,
                        shrink_axis_mask,
                    ],
                ]
            )

        return arg_list

    # tf.stack axis can be [0, rank(input)]
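    # e.g. a rank-2 input yields axes 0, 1, and 2 (stacking inserts one new
    # dimension, so the valid range includes rank(input) itself)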
    def agStack(op, shapes, rng):
        axes = []
        for i in range(len(shapes) + 1):
            axes.append(["_axis{}".format(i), [i]])
        return axes

    def agMirrorPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for mode in ["REFLECT", "SYMMETRIC"]:
            for left in range(3):
                for right in range(3):
                    paddings = np.zeros((rank, 2), dtype=np.int32)
                    is_valid = True

                    # Fill in the padding parameter if the values are valid on each dimension,
                    # otherwise drop that case.
                    for d in range(rank):
                        paddings[d, 0] = left
                        paddings[d, 1] = right

                        # In "REFLECT" mode, paddings must be no greater than tensor dim size - 1.
                        if mode == "REFLECT":
                            if (left > shapes[d] - 1) or (right > shapes[d] - 1):
                                is_valid = False
                                break

                        # In "SYMMETRIC" mode, paddings must be no greater than tensor dim size.
                        else:
                            if (left > shapes[d]) or (right > shapes[d]):
                                is_valid = False
                                break

                    if is_valid:
                        arg_list.append(
                            [
                                "_pad{}{}_{}".format(left, right, mode[0:3].lower()),
                                [paddings, mode],
                            ]
                        )
        return arg_list

    def agPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for left in range(3):
            for right in range(3):
                paddings = np.zeros((rank, 2), dtype=np.int32)
                for d in range(rank):
                    paddings[d, 0] = left
                    paddings[d, 1] = right

                arg_list.append(["_pad{}{}".format(left, right), [paddings]])
        return arg_list

    def agFill(op, shapes, rng):
        values = []
        for i in range(4):
            value = rng.integers(0, 10, dtype=np.int32)
            values.append(["_value{}".format(value), [shapes, value]])
        return values

    def getValuesToSum(total, rng):
        # Get a list of random integers that sum up to 'total'
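        # e.g. getValuesToSum(7, rng) might return [4, 2, 1]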
        vals = []

        # rng.integers() requires min and max to differ, so when the
        # remainder reaches 1, stop drawing and append the final 1 below
        while total > 1:
            vals.append(rng.integers(1, total))
            total = total - vals[-1]

        if total == 1:
            vals.append(1)

        return vals

    def agSplit(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Burn a few draws from the random number generator to get
        # a better range of axes across shapes
        for i in range(rank):
            for j in range(shapes[i]):
                rng.integers(shapes[i])

        for i in range(3):
            # Need to generate tests for both the num_splits and size_vector versions.
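            # e.g. for an axis of size 16: num_splits could be 2, while a
            # size vector could be [9, 4, 2, 1]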
            axis = rng.choice(np.arange(0, rank))

            # For num_splits, get a few divisors of the given axis
            divs = ArgGen.getFactors(shapes[axis], 2)

            if divs:
                # Get no more than 2 samples
                splits = list(rng.choice(divs, size=2))

                for s in splits:
                    arg_list.append(
                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
                    )

            # For vector splits, get a list of integers that sum up to the axis size
            vals = ArgGen.getValuesToSum(shapes[axis], rng)

            if len(vals) > 1:
                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])

        return arg_list

    def agTile(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # create 1D multiples list
        multiples = list()
        for i in range(rank):
            multiples.append(rng.integers(1, 4))

        multiples_str = "x".join(list(str(i) for i in multiples))

        arg_list.append(["_tile_{}".format(multiples_str), [multiples]])

        return arg_list

    def agGather(op, shapes, rng):
        args = []
        for batch_dims in range(len(shapes) - 1):
            for axis in range(batch_dims, len(shapes)):
                # indices values must be within [0, shapes[axis])

                # Create an arbitrary shape for the indices
                indices_rank = rng.integers(batch_dims + 1, 4)
                indices_shape = rng.integers(1, 8, size=indices_rank)

                # Copy in the batch dimensions because they must match
                for b in range(batch_dims):
                    indices_shape[b] = shapes[b]

                # Calculate total element count
                indices_size = 1
                for j in range(indices_rank):
                    indices_size = indices_shape[j] * indices_size

                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
                    indices_shape
                )

                args.append(
                    [
                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
                        [indices, batch_dims, axis],
                    ]
                )
        return args

    def agGatherND(op, shapes, rng):
        args = []

        for N in range(1, len(shapes) - 1):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N
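            # e.g. with N = 2, an indices tensor of shape (5, 2) holds five
            # coordinate pairs indexing the first two dimensions of the input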

            indices_count = 1
            for i in range(indices_rank - 1):
                indices_count = indices_count * indices_shape[i]

            indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32)

            for i in range(indices_count):
                for j in range(N):
                    indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0]

            indices = indices_list.reshape(indices_shape)

            args.append(["_n{}".format(N), [indices]])

        return args

    def agScatterND(op, shapes, rng):
        args = []

        # ScatterND has to generate a constant shapes tensor, indices
        # tensor, and a tensor of updates. Unfortunately, the updates
        # need to be a size that's based on the N generated in this
        # function and the dtype known only in the TensorGen function,
        # but not in ArgGen.
        #
        # There are many bad ways to solve this and we'll choose the
        # least of the evils which still gives reasonable coverage of
        # the possible operand shapes.
        for N in range(1, len(shapes)):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            # Store the shapes, and the indices value tensor as arguments.
            args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]])

        return args

    def agSpaceToBatch(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []

        for i in range(block_rank):
            block_size = 2
            padding_size = block_size - (shapes[i + 1] % block_size)
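            # (note: when the dim is already a multiple of block_size this
            # still pads one full extra block)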
            block_shape.append(block_size)
            padding_shape.append([0, padding_size])

        args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]])
        return args

    def agBatchToSpace(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []
        block_prod = 1

        for i in range(block_rank):
            block_size = 2
            block_prod = block_prod * block_size
            crop_size = 0
            block_shape.append(block_size)
            padding_shape.append([0, crop_size])

        # batch / prod(block_shape[i]) must be an integer; the test transposes
        # to swap depth and batch, so shapes[-1] would be the batch dim
        if shapes[-1] % block_prod == 0:
            args.append(
                ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]
            )

        return args

    def agSpaceToDepth(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2

        # spatial dimension must be divisible by block_size
        if shapes[1] % block_size != 0 or shapes[2] % block_size != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    def agDepthToSpace(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2
        # depth dimension must be divisible by block_size * block_size
        if shapes[3] % (block_size * block_size) != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    def agFakequant(op, shapes, rng):
        args = []
        for num_bits in [8, 16]:
            for narrow in [False, True]:
                args.append(
                    ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]]
                )

        return args

    def agShift(op, shapes, rng):
        args = []

        for shift in rng.integers(0, 32, size=8):
            args.append(["_shift{}".format(shift), [shift]])

        return args

    def agFloat(op, shapes, rng):
        args = []

        # enumerate so each alpha gets a distinct name suffix
        for i, alpha in enumerate(np.float32(rng.random(size=2))):
            args.append(["_{}".format(i), [alpha]])

        return args

    # Similar to agAxes, but tf.OneHot only allows axis within [-1, rank(input)]
    def agOneHot(op, shapes, rng):
        axes = []
        for i in range(-1, len(shapes) + 1, 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

    def agRFFT2d(op, shape, rng):
        args = []

        # Must be rank 3 input tensor
        if len(shape) != 3:
            return []

        # Check rfft2d with enforced fft_length
        for fft_length_h in [2, 32]:
            for fft_length_w in [2, 8, 16]:
                fft_length = [fft_length_h, fft_length_w]
                args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])

        # Check rfft2d with no fft_length provided (fft_length=None).
        # In this case, the height and width of the input should be
        # used for the calculation. Therefore, we need to check that
        # the input shape is already a power of two.
        def is_power_of_two(x):
            return math.log(x, 2).is_integer()
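        # (float log is adequate for the small extents used here; an exact
        # integer-only alternative would be: x > 0 and x & (x - 1) == 0)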

        height, width = shape[1:3]
        if is_power_of_two(height) and is_power_of_two(width):
            args.append(["_fft_length_None", [None]])

        return args