# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np


class ArgGen:
    """Argument generator functions. These functions take an op dict, a shape,
    and a random number generator to create arguments for an operator.
    Methods are prefixed with 'ag' to make searching easy."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(op, shapes, rng):
        """A trivial argument generator for operators that only take tensor
        operands"""
        return [("", [])]
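
    # A generator's return value is a list of [name_suffix, arg_list] entries:
    # the suffix is appended to the generated test name and arg_list is
    # expanded into the downstream build_operator call. A hedged sketch of the
    # contract (the call site shown is illustrative, not part of this module):
    #
    #   for suffix, args in ArgGen.agAxes(op, (2, 3), rng):
    #       build_operator(..., *args)  # e.g. suffix "_axis_0", args [0]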

    # Build the axis argument for operators where we want to iterate over N axes
    # as an argument
    @staticmethod
    def agAxes(op, shapes, rng):
        axes = []
        for i in range(-len(shapes), len(shapes), 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes
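
    # Worked example (illustrative): for shapes=(2, 3), range(-2, 2) visits
    # i = -2, -1, 0, 1, so agAxes returns
    # [["_axis_m2", [-2]], ["_axis_m1", [-1]], ["_axis_0", [0]], ["_axis_1", [1]]].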

    # Build the axis LIST argument for operators that take an axis list.
    # This builds a list of each axis individually, plus one element
    # that contains a list of all axes. Note that we need to pack the list in
    # an additional list so that it isn't exploded when being passed to the
    # build_operator function.
    # tensor_arg_count not used
    @staticmethod
    def agAxesList(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc, [a]])

        axes_list.append(["_axisall", [list(range(len(shapes)))]])
        axes_list.append(["_axisall_none", [None]])
        return axes_list
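
    # Packing example (illustrative): agAxes yields ["_axis_0", [0]]; wrapping
    # gives ["_axis_0", [[0]]], so the operator receives the axis list [0] as
    # a single argument instead of an exploded scalar.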

    @staticmethod
    def agAxesListKeepdims(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc + "_keep0", [a, False]])
            # avoid trying to reduce an axis of shape 1, as the TFL converter
            # will optimize away the entire reduction
            if (a[0] >= 0 and shapes[a[0]] != 1) or (
                a[0] < 0 and shapes[len(shapes) + a[0]] != 1
            ):
                axes_list.append([desc + "_keep1", [a, True]])

        axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]])
        axes_list.append(["_axisall_keep0_none", [None, False]])
        # another instance where the reduce gets optimized out.
        if len(shapes) != 1:
            axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]])
            axes_list.append(["_axisall_keep1_none", [None, True]])
        # no longer test axis empty, as TFL converter optimizes the reduce out
        return axes_list

    # conv2d argument generators build the TF constants
    @staticmethod
    def agConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    for dilation_h in [1, 2]:
                        for dilation_w in [1, 2]:

                            # Disqualify argument combinations that would cause
                            # an illegal convolution

                            if (padding == "VALID") and (
                                (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                                or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                            ):
                                continue

                            if (
                                (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                != 0
                            ) or (
                                (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_dilat{}{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        dilation_h,
                                        dilation_w,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        padding,
                                        [dilation_h, dilation_w],
                                    ],
                                ]
                            )
        return arg_list
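
    # Worked example (illustrative): for an input of shape [N, 9, 9, C] with a
    # 3x3 filter, stride 2 and dilation 1, (9 - 1 - (3 - 1) * 1) % 2 == 0, so
    # the combination is kept; an [N, 8, 8, C] input is dropped because
    # (8 - 1 - 2) % 2 != 0 and the convolution has no exact integer output.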

    # conv3d argument generators build the TF constants
    @staticmethod
    def agConv3d(op, shapes, rng):
        arg_list = []

        # input shape = [OC, KD, KH, KW, IC]
        # Must be rank 5
        if len(shapes) != 5:
            return arg_list

        if len(op["filter"]) < 3:
            return arg_list

        filter_d, filter_h, filter_w = op["filter"]

        # strides, padding, dilations
        for stride_d in [1, 2]:
            for stride_h in [1, 2]:
                for stride_w in [1, 2]:
                    for padding in ["SAME", "VALID"]:
                        for dilation_d in [1, 2]:
                            for dilation_h in [1, 2]:
                                for dilation_w in [1, 2]:

                                    # Disqualify argument combinations that would cause
                                    # an illegal convolution
                                    # fmt: off
                                    if (padding == "VALID") and (
                                        (shapes[1] - (filter_d - 1) * 2 - dilation_d) <= 0
                                        or (shapes[2] - (filter_h - 1) * 2 - dilation_h) <= 0
                                        or (shapes[3] - (filter_w - 1) * 2 - dilation_w) <= 0
                                    ):
                                        continue

                                    if (
                                        (shapes[1] - 1 - (filter_d - 1) * dilation_d) % stride_d
                                        != 0
                                    ) or (
                                        (shapes[2] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                        != 0
                                    ) or (
                                        (shapes[3] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                        != 0
                                    ):
                                        # Not an exact integer output
                                        continue
                                    # fmt: on

                                    # TODO investigate the error `CPU implementation of Conv3D
                                    # currently only supports dilated rates of 1.` from TensorFlow.
                                    # Only test dilations = [1, 1, 1, 1, 1] for now.
                                    if (
                                        (dilation_d != 1)
                                        or (dilation_h != 1)
                                        or (dilation_w != 1)
                                    ):
                                        continue

                                    # TensorFlow expects strides to be a list of ints of length >= 5.
                                    # Strides and dilations in the batch and channel dimensions must be 1.
                                    arg_list.append(
                                        [
                                            "_st{}{}{}{}{}_pad{}_dilat{}{}{}{}{}".format(
                                                1,
                                                stride_d,
                                                stride_h,
                                                stride_w,
                                                1,
                                                padding,
                                                1,
                                                dilation_d,
                                                dilation_h,
                                                dilation_w,
                                                1,
                                            ),
                                            [
                                                [1, stride_d, stride_h, stride_w, 1],
                                                padding,
                                                [
                                                    1,
                                                    dilation_d,
                                                    dilation_h,
                                                    dilation_w,
                                                    1,
                                                ],
                                            ],
                                        ]
                                    )
        return arg_list
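
    # Note (hedged): tf.nn.conv3d takes NDHWC operands, which is why the
    # stride and dilation lists above are padded to length 5 with 1s in the
    # batch and channel positions, e.g. strides [1, 2, 1, 2, 1] for
    # (stride_d, stride_h, stride_w) = (2, 1, 2).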

    # depthwise conv2d argument generators build the TF constants
    @staticmethod
    def agDepthwiseConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations. Depthwise conv2d is the same as conv2d
        # except that strides in h/w must be the same and the argument must be
        # formatted as [1, stride_h, stride_w, 1] in TF.
        for stride in [1, 2]:
            for padding in ["SAME", "VALID"]:
                for dilation_h in [1, 2]:
                    for dilation_w in [1, 2]:

                        # Disqualify argument combinations that would cause an illegal
                        # convolution

                        if (padding == "VALID") and (
                            (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                            or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                        ):
                            continue

                        # When dilation is used, stride must be 1x1 (TF rules)
                        if dilation_h > 1 or dilation_w > 1:
                            if stride > 1:
                                continue

                        # Dilation must evenly divide the tensor. Some of our inputs
                        # intentionally use odd-sized tensors.
                        if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
                            continue

                        if (
                            (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
                        ) or (
                            (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
                        ):
                            # Not an exact integer output
                            continue

                        arg_list.append(
                            [
                                "_st{}{}_pad{}_dilat{}{}".format(
                                    stride, stride, padding, dilation_h, dilation_w
                                ),
                                [
                                    [1, stride, stride, 1],
                                    padding,
                                    [dilation_h, dilation_w],
                                ],
                            ]
                        )
        return arg_list

    # transpose conv2d argument generators build the TF constants
    @staticmethod
    def agTransposeConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    if padding == "SAME":
                        out_height = shapes[1] * stride_h
                        out_width = shapes[2] * stride_w
                    else:  # padding == 'VALID'
                        out_height = (shapes[1] - 1) * stride_h + filter_h
                        out_width = (shapes[2] - 1) * stride_w + filter_w

                    output_shape = [shapes[0], out_height, out_width, shapes[3] * 2]
                    arg_list.append(
                        [
                            "_st{}{}_pad{}".format(stride_h, stride_w, padding),
                            [output_shape, [stride_h, stride_w], padding],
                        ]
                    )
        return arg_list
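
    # Worked example (illustrative): for input [N, 4, 4, C] with a 3x3 filter
    # and stride 2, SAME padding yields output_shape [N, 8, 8, 2C] (4 * 2),
    # while VALID yields [N, 9, 9, 2C] ((4 - 1) * 2 + 3).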

    @staticmethod
    def agPooling(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for kernel_h in [1, 2]:
                    for kernel_w in [1, 2]:
                        for padding in ["SAME", "VALID"]:

                            if (padding == "VALID") and (
                                (shapes[1] % (kernel_h * stride_h) > 0)
                                or (shapes[2] % (kernel_w * stride_w) > 0)
                                or (shapes[1] <= kernel_h)
                                or (shapes[2] <= kernel_w)
                            ):
                                continue

                            if (padding == "SAME") and (
                                (shapes[1] < kernel_h) or (shapes[2] < kernel_w)
                            ):
                                continue

                            if ((shapes[1] - kernel_h) % stride_h != 0) or (
                                (shapes[2] - kernel_w) % stride_w != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_kern{}{}".format(
                                        stride_h, stride_w, padding, kernel_h, kernel_w
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        [kernel_h, kernel_w],
                                        padding,
                                    ],
                                ]
                            )
        return arg_list
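
    # Worked example (illustrative): an [N, 8, 8, C] input with kernel 2 and
    # stride 2 passes every check for VALID padding (8 % (2 * 2) == 0 and
    # (8 - 2) % 2 == 0), whereas an [N, 6, 6, C] input is rejected because
    # 6 % 4 != 0.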

    @staticmethod
    def getFactors(val, start=1):
        # Note that the range end is int(np.sqrt(val)), exclusive, so only
        # factors strictly below sqrt(val) are returned
        factors = []
        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors
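
    # Example (illustrative): getFactors(12) walks i = 1..2 (int(sqrt(12)) is
    # 3 and the range end is exclusive) and returns [1, 2];
    # getFactors(12, start=2) returns [2].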

    @staticmethod
    def agReshape(op, shapes, rng):
        # This is slow code. Fortunately, the numbers involved are small
        arg_list = []

        total_elements = 1
        for s in shapes:
            total_elements *= s

        # Find integer factors of this shape
        factors = ArgGen.getFactors(total_elements)

        for rank in range(1, len(shapes) + 1):
            if len(factors) < rank:
                break

            new_shape = []
            remaining_elements = total_elements

            # Randomly shuffle the factors and iteratively pick from the factors
            # of the remaining elements
            shuffled_factors = rng.permutation(factors)
            for i in range(rank):
                # Pick rank factors; whatever is left over is appended below
                new_shape.append(shuffled_factors[0])
                remaining_elements = remaining_elements // shuffled_factors[0]
                shuffled_factors = rng.permutation(
                    ArgGen.getFactors(remaining_elements)
                )
            new_shape.append(remaining_elements)

            # Don't do no-op reshapes because TFLite optimizes out the op
            if new_shape == list(shapes):
                continue

            arg_list.append(["_rank{}".format(rank), [new_shape]])

        return arg_list
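
    # Worked example (illustrative): for shapes=(2, 6), total_elements is 12
    # and getFactors(12) returns [1, 2]; an iteration that picks 2 builds
    # new_shape [2, 6], which is skipped as a no-op, while picking 1 yields
    # [1, 12].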

    @staticmethod
    def agTranspose(op, shapes, rng):
        arg_list = []

        # Must have at least two dimensions to transpose
        if len(shapes) < 2:
            return arg_list

        # Pick a bunch of random permutations
        range_arr = np.arange(len(shapes))
        for i in range(len(shapes)):
            perm = rng.permutation(range_arr).astype(np.int32)
            # Skip the identity permutation, which would be a no-op transpose
            if np.allclose(perm, range_arr):
                continue
            arg_list.append(["_permute{}".format(i), [perm]])

        return arg_list

    @staticmethod
    def agSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random start points and sizes
            start = np.empty((rank), dtype=int)
            size = np.empty((rank), dtype=int)
            for j in range(rank):
                if shapes[j] > 2:
                    start[j] = rng.integers(0, shapes[j] - 2)
                    size[j] = rng.integers(1, shapes[j] - start[j] - 1)
                else:
                    start[j] = 0
                    size[j] = shapes[j]

            arg_list.append(["_perm{}".format(i), [start, size]])

        return arg_list

    @staticmethod
    def agStridedSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Reference model is limited to rank=6 internally right now
        if rank > 3:
            return arg_list

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random begin points, end points, and strides
            begin = np.empty((rank), dtype=int)
            end = np.empty((rank), dtype=int)
            strides = np.empty((rank), dtype=int)

            begin_mask = rng.integers(0, (1 << (rank - 1)))
            end_mask = rng.integers(0, (1 << (rank - 1)))

            for j in range(rank):

                if begin_mask & (1 << j) or shapes[j] < 2:
                    begin[j] = 0
                else:
                    begin[j] = rng.integers(0, shapes[j] - 1)

                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
                    end[j] = shapes[j]
                else:
                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)

                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)

                if not possible_stride:
                    strides[j] = 1
                else:
                    strides[j] = rng.choice(possible_stride)

            # Randomly set the masks, except ellipsis_mask and new_axis_mask,
            # which must be zero for now. For begin_mask/end_mask to work,
            # strides must be adjusted to still be divisible...
            ellipsis_mask = 0
            new_axis_mask = 0

            # if rng.choice([0, 1]) and rank > 1:
            #     new_axis_mask = 1 << rng.integers(0, rank - 1)
            # else:
            #     new_axis_mask = 0

            if rng.choice([0, 1]) and rank > 1:
                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
            else:
                shrink_axis_mask = 0

            # Only one of these bits may be set. Prefer shrink_axis_mask.
            new_axis_mask = new_axis_mask & ~shrink_axis_mask

            arg_list.append(
                [
                    "_perm{}".format(i),
                    [
                        begin,
                        end,
                        strides,
                        begin_mask,
                        end_mask,
                        ellipsis_mask,
                        new_axis_mask,
                        shrink_axis_mask,
                    ],
                ]
            )

        return arg_list
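
    # Mask example (illustrative): with rank 3 and begin_mask 0b010, begin[1]
    # is forced to 0 above, and a shrink_axis_mask of 0b100 asks
    # tf.strided_slice to drop axis 2 from the result; that is why at most one
    # of shrink_axis_mask and new_axis_mask may be set for a given bit.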

    # tf.stack axis can be in [0, rank(input)]
    @staticmethod
    def agStack(op, shapes, rng):
        axes = []
        for i in range(len(shapes) + 1):
            axes.append(["_axis{}".format(i), [i]])
        return axes

    @staticmethod
    def agPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for left in range(3):
            for right in range(3):
                paddings = np.zeros((rank, 2), dtype=np.int32)
                for d in range(rank):
                    paddings[d, 0] = left
                    paddings[d, 1] = right

                arg_list.append(["_pad{}{}".format(left, right), [paddings]])
        return arg_list

    @staticmethod
    def agFill(op, shapes, rng):
        values = []
        for i in range(4):
            value = rng.integers(0, 10, dtype=np.int32)
            values.append(["_value{}".format(value), [shapes, value]])
        return values

    @staticmethod
    def getValuesToSum(total, rng):
        # Get a list of random integers that sum up to 'total'
        vals = []

        # rng.integers() requires min and max to be different, so once the
        # remainder reaches 1 we stop and append it directly
        while total > 1:
            vals.append(rng.integers(1, total))
            total = total - vals[-1]

        if total == 1:
            vals.append(1)

        return vals
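
    # Example (illustrative): getValuesToSum(7, rng) might return [3, 2, 1, 1]
    # or [5, 1, 1]; the entries always sum to the original total, and the
    # final 1 is appended outside the loop because rng.integers(1, 1) would
    # raise on an empty range.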

    @staticmethod
    def agSplit(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Shuffle the random number generator a few more times to get
        # a better range of axes across shapes
        for i in range(rank):
            for j in range(shapes[i]):
                rng.integers(shapes[i])

        for i in range(3):
            # Need to generate tests for both the num_splits and size_vector versions.
            axis = rng.choice(np.arange(0, rank))

            # For num_splits, get a few divisors of the given axis
            divs = ArgGen.getFactors(shapes[axis], 2)

            if divs:
                # Get no more than 2 samples
                splits = list(rng.choice(divs, size=2))

                for s in splits:
                    arg_list.append(
                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
                    )

            # For vector splits, get a list of integers that sum up to the axis size
            vals = ArgGen.getValuesToSum(shapes[axis], rng)

            if len(vals) > 1:
                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])

        return arg_list
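
    # Example (illustrative): for shapes=(4, 12) and axis=1, the num_splits
    # form can emit ["_split2_axis1", [2, 1]] (2 divides 12), while the
    # size-vector form might emit ["_splitv_axis1", [[7, 3, 1, 1], 1]]; the
    # two mirror tf.split's num_or_size_splits argument.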

    @staticmethod
    def agTile(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # create 1D multiples list
        multiples = list()
        for i in range(rank):
            multiples.append(rng.integers(1, 4))

        multiples_str = "x".join(list(str(i) for i in multiples))

        arg_list.append(["_tile_{}".format(multiples_str), [multiples]])

        return arg_list

    @staticmethod
    def agGather(op, shapes, rng):
        args = []
        for batch_dims in range(len(shapes) - 1):
            for axis in range(batch_dims, len(shapes)):
                # indices values must be within [0, shapes[axis])

                # Create an arbitrary shape for the indices
                indices_rank = rng.integers(batch_dims + 1, 4)
                indices_shape = rng.integers(1, 8, size=indices_rank)

                # Copy in the batch dimensions because they must match
                for b in range(batch_dims):
                    indices_shape[b] = shapes[b]

                # Calculate total element count
                indices_size = 1
                for j in range(indices_rank):
                    indices_size = indices_shape[j] * indices_size

                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
                    indices_shape
                )

                args.append(
                    [
                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
                        [indices, batch_dims, axis],
                    ]
                )
        return args
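
    # Shape example (illustrative): with shapes=(3, 5, 7), batch_dims=1 and
    # axis=1, indices might get shape (3, 2); the leading dimension is copied
    # from shapes[0] and every value lies in [0, 5), matching tf.gather's
    # requirement that batch dimensions agree.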

    @staticmethod
    def agGatherND(op, shapes, rng):
        args = []

        for N in range(1, len(shapes) - 1):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            indices_count = 1
            for i in range(indices_rank - 1):
                indices_count = indices_count * indices_shape[i]

            indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32)

            for i in range(indices_count):
                for j in range(N):
                    indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0]

            indices = indices_list.reshape(indices_shape)

            args.append(["_n{}".format(N), [indices]])

        return args
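
    # Shape example (illustrative): for shapes=(5, 6, 7, 8) and N=2, indices
    # might get shape (3, 2); each row is a coordinate (i, j) with 0 <= i < 5
    # and 0 <= j < 6, since tf.gather_nd indexes the leading N dimensions of
    # the input.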

    @staticmethod
    def agScatterND(op, shapes, rng):
        args = []

        # ScatterND has to generate a constant shapes tensor, indices
        # tensor, and a tensor of updates. Unfortunately, the updates
        # need to be a size that's based on the N generated in this
        # function and the dtype known only in the TensorGen function,
        # but not in ArgGen.
        #
        # There are many bad ways to solve this and we'll choose the
        # least of the evils, which still gives reasonable coverage of
        # the possible operand shapes.
        for N in range(1, len(shapes)):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            # Store the shapes and the indices shape as arguments, along with
            # N and the rng, so TensorGen can draw the actual values.
            args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]])

        return args
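
    # Example (illustrative): for shapes=(4, 4, 4) and N=2, one entry might be
    # ["_n2", [(4, 4, 4), indices_shape, 2, rng]] with indices_shape ending in
    # 2; the actual indices and updates are deferred to TensorGen, which knows
    # the dtype.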

    @staticmethod
    def agSpaceToBatch(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least a rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []

        for i in range(block_rank):
            block_size = 2
            # Note: when the dimension is already divisible by block_size,
            # this pads a full extra block rather than zero
            padding_size = block_size - (shapes[i + 1] % block_size)
            block_shape.append(block_size)
            padding_shape.append([0, padding_size])

        args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]])
        return args

    @staticmethod
    def agBatchToSpace(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least a rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []
        block_prod = 1

        for i in range(block_rank):
            block_size = 2
            block_prod = block_prod * block_size
            crop_size = 0
            block_shape.append(block_size)
            padding_shape.append([0, crop_size])

        # batch / prod(block_shape[i]) must be an integer
        # transpose to swap depth and batch, so shape[-1] would be the batch dim
        if shapes[-1] % block_prod == 0:
            args.append(
                ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]
            )

        return args

    @staticmethod
    def agSpaceToDepth(op, shapes, rng):
        # must be a rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2

        # spatial dimensions must be divisible by block_size
        if shapes[1] % block_size != 0 or shapes[2] % block_size != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    @staticmethod
    def agDepthToSpace(op, shapes, rng):
        # must be a rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2
        # depth dimension must be divisible by block_size * block_size
        if shapes[3] % (block_size * block_size) != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    @staticmethod
    def agFakequant(op, shapes, rng):
        args = []
        for num_bits in [8, 16]:
            for narrow in [False, True]:
                args.append(
                    ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]]
                )

        return args

    @staticmethod
    def agShift(op, shapes, rng):
        args = []

        for shift in rng.integers(0, 32, size=8):
            args.append(["_shift{}".format(shift), [shift]])

        return args

    @staticmethod
    def agFloat(op, shapes, rng):
        args = []

        # enumerate gives each alpha a distinct test-name suffix
        for i, alpha in enumerate(np.float32(rng.random(size=2))):
            args.append(["_{}".format(i), [alpha]])

        return args

    # Similar to agAxes, but tf.one_hot only allows axis in [-1, rank(input)]
    @staticmethod
    def agOneHot(op, shapes, rng):
        axes = []
        for i in range(-1, len(shapes) + 1, 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes
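

# A minimal smoke test (a sketch, assuming numpy's Generator API); it prints
# the argument lists a couple of generators produce for fixed shapes.
if __name__ == "__main__":
    rng = np.random.default_rng(seed=42)
    for suffix, args in ArgGen.agAxes({}, (2, 3, 4), rng):
        print("agAxes", suffix, args)
    for suffix, args in ArgGen.agPooling({}, (1, 8, 8, 4), rng):
        print("agPooling", suffix, args)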