# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from tosa.DType import DType

DTYPE_ATTRIBUTES = {
    DType.BOOL: {"str": "b", "width": 1},
    DType.INT4: {"str": "i4", "width": 4},
    DType.INT8: {"str": "i8", "width": 8},
    DType.UINT8: {"str": "u8", "width": 8},
    DType.INT16: {"str": "i16", "width": 16},
    DType.UINT16: {"str": "u16", "width": 16},
    DType.INT32: {"str": "i32", "width": 32},
    DType.INT48: {"str": "i48", "width": 48},
    DType.FP16: {"str": "f16", "width": 16},
    DType.BF16: {"str": "bf16", "width": 16},
    DType.FP32: {"str": "f32", "width": 32},
}


class ArgGen:
    """Argument generator functions. These functions take a shape and dtype to
    create arguments for an operator. Methods are prefixed with 'ag' to make
    search easy."""

    def __init__(self):
        pass

    def typeStr(dtype):
        if dtype in DTYPE_ATTRIBUTES:
            return DTYPE_ATTRIBUTES[dtype]["str"]
        else:
            raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))

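    # For illustration (an added example, not part of the generator API):
    # typeStr builds test-name suffixes, e.g.
    #   ArgGen.typeStr(DType.FP16)  -> "f16"
    #   ArgGen.typeStr(DType.INT48) -> "i48"
    # and any dtype missing from DTYPE_ATTRIBUTES raises an Exception rather
    # than silently producing a malformed test name.
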
    @staticmethod
    def agNone(op, shapes, rng):
        """A trivial argument generator for operators that only take tensor
        operands"""
        return [("", [])]

    # Build the axis argument for operators where we want to iterate over N axes
    # as an argument
    @staticmethod
    def agAxes(op, shapes, rng):
        axes = []
        for i in range(-len(shapes), len(shapes), 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

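    # For illustration, agAxes(op, (4, 5), rng) returns
    #   [["_axis_m2", [-2]], ["_axis_m1", [-1]], ["_axis_0", [0]], ["_axis_1", [1]]]
    # i.e. one entry per negative and per non-negative axis of the shape.
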
    # Build the axis LIST argument for operators that take an axis list.
    # This builds a list of each axis individually, plus one element
    # that contains a list of all axes. Note that we need to pack the list in
    # an additional list so that it isn't exploded when being passed to the
    # build_operator function.
    # tensor_arg_count not used
    def agAxesList(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc, [a]])

        axes_list.append(["_axisall", [list(range(len(shapes)))]])
        axes_list.append(["_axisall_none", [None]])
        return axes_list

    def agAxesListKeepdims(op, shapes, rng):
        axes = ArgGen.agAxes(op, shapes, rng)
        axes_list = []
        for desc, a in axes:
            axes_list.append([desc + "_keep0", [a, False]])
            # avoid trying to reduce an axis of shape 1, as the TFL converter
            # will optimize away the entire reduction
            if (a[0] >= 0 and shapes[a[0]] != 1) or (
                a[0] < 0 and shapes[len(shapes) + a[0]] != 1
            ):
                axes_list.append([desc + "_keep1", [a, True]])

        axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]])
        axes_list.append(["_axisall_keep0_none", [None, False]])
        # another instance where the reduce gets optimized out.
        if len(shapes) != 1:
            axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]])
            axes_list.append(["_axisall_keep1_none", [None, True]])
        # no longer test axis empty, as TFL converter optimizes the reduce out
        return axes_list

    # conv2d argument generators build the TF constants
    def agConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations,
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    for dilation_h in [1, 2]:
                        for dilation_w in [1, 2]:

                            # Disqualify argument combinations that would cause
                            # an illegal convolution

                            if (padding == "VALID") and (
                                (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                                or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                            ):
                                continue

                            if (
                                (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                != 0
                            ) or (
                                (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                != 0
                            ):
                                # Not an exact integer output
                                continue

                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_dilat{}{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        dilation_h,
                                        dilation_w,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        padding,
                                        [dilation_h, dilation_w],
                                    ],
                                ]
                            )
        return arg_list

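    # Worked example of the checks above: with shapes = [1, 16, 16, 8] and
    # op["filter"] = [3, 3], the stride_h = 2 / dilation_h = 1 combination is
    # rejected because (16 - 1 - (3 - 1) * 1) % 2 == 1, i.e. the window does
    # not tile the height exactly and the output size would not be an exact
    # integer.
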
    # conv3d argument generators build the TF constants
    def agConv3d(op, shapes, rng):
        arg_list = []

        # Input shape must be rank 5: [N, D, H, W, C]
        if len(shapes) != 5:
            return arg_list

        if len(op["filter"]) < 3:
            return arg_list

        filter_d, filter_h, filter_w = op["filter"]

        # strides, padding, dilations,
        for stride_d in [1, 2]:
            for stride_h in [1, 2]:
                for stride_w in [1, 2]:
                    for padding in ["SAME", "VALID"]:
                        for dilation_d in [1, 2]:
                            for dilation_h in [1, 2]:
                                for dilation_w in [1, 2]:

                                    # Disqualify argument combinations that would cause
                                    # an illegal convolution
                                    # fmt: off
                                    if (padding == "VALID") and (
                                        (shapes[1] - (filter_d - 1) * 2 - dilation_d) <= 0
                                        or (shapes[2] - (filter_h - 1) * 2 - dilation_h) <= 0
                                        or (shapes[3] - (filter_w - 1) * 2 - dilation_w) <= 0
                                    ):
                                        continue

                                    if (
                                        (shapes[1] - 1 - (filter_d - 1) * dilation_d) % stride_d
                                        != 0
                                    ) or (
                                        (shapes[2] - 1 - (filter_h - 1) * dilation_h) % stride_h
                                        != 0
                                    ) or (
                                        (shapes[3] - 1 - (filter_w - 1) * dilation_w) % stride_w
                                        != 0
                                    ):
                                        # Not an exact integer output
                                        continue
                                    # fmt: on

                                    # TODO investigate the TensorFlow error `CPU implementation
                                    # of Conv3D currently only supports dilated rates of 1.`
                                    # Only test dilations = [1, 1, 1, 1, 1] for now.
                                    if (
                                        (dilation_d != 1)
                                        or (dilation_h != 1)
                                        or (dilation_w != 1)
                                    ):
                                        continue

                                    # TensorFlow expects strides to be a list of ints with
                                    # length >= 5. Strides and dilations in the batch and
                                    # depth dimensions must be 1.
                                    arg_list.append(
                                        [
                                            "_st{}{}{}{}{}_pad{}_dilat{}{}{}{}{}".format(
                                                1,
                                                stride_d,
                                                stride_h,
                                                stride_w,
                                                1,
                                                padding,
                                                1,
                                                dilation_d,
                                                dilation_h,
                                                dilation_w,
                                                1,
                                            ),
                                            [
                                                [1, stride_d, stride_h, stride_w, 1],
                                                padding,
                                                [
                                                    1,
                                                    dilation_d,
                                                    dilation_h,
                                                    dilation_w,
                                                    1,
                                                ],
                                            ],
                                        ]
                                    )
        return arg_list

    # depthwise conv2d argument generators build the TF constants
    def agDepthwiseConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations. Depthwise conv2d is the same as conv2d
        # except that strides in h/w must be the same and the argument must be
        # formatted as [1, stride_h, stride_w, 1] in TF.
        for stride in [1, 2]:
            for padding in ["SAME", "VALID"]:
                for dilation_h in [1, 2]:
                    for dilation_w in [1, 2]:

                        # Disqualify argument combinations that would cause an illegal
                        # convolution

                        if (padding == "VALID") and (
                            (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0
                            or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0
                        ):
                            continue

                        # When dilation is used, stride must be 1x1 (TF rules)
                        if dilation_h > 1 or dilation_w > 1:
                            if stride > 1:
                                continue

                        # Dilation must evenly divide the tensor. Some of our inputs
                        # intentionally use odd-sized tensors.
                        if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
                            continue

                        if (
                            (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
                        ) or (
                            (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
                        ):
                            # Not an exact integer output
                            continue

                        arg_list.append(
                            [
                                "_st{}{}_pad{}_dilat{}{}".format(
                                    stride, stride, padding, dilation_h, dilation_w
                                ),
                                [
                                    [1, stride, stride, 1],
                                    padding,
                                    [dilation_h, dilation_w],
                                ],
                            ]
                        )
        return arg_list

    # transpose conv2d argument generators build the TF constants
    def agTransposeConv2d(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        filter_h, filter_w = op["filter"]

        # strides, padding, dilations,
        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for padding in ["SAME", "VALID"]:
                    if padding == "SAME":
                        out_height = (shapes[1]) * stride_h
                        out_width = (shapes[2]) * stride_w
                    else:  # padding == 'VALID'
                        out_height = (shapes[1] - 1) * stride_h + filter_h
                        out_width = (shapes[2] - 1) * stride_w + filter_w

                    output_shape = [shapes[0], out_height, out_width, shapes[3] * 2]
                    arg_list.append(
                        [
                            "_st{}{}_pad{}".format(stride_h, stride_w, padding),
                            [output_shape, [stride_h, stride_w], padding],
                        ]
                    )
        return arg_list

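    # Worked example of the output shape computed above: with
    # shapes = [1, 8, 8, 4], op["filter"] = [3, 3] and stride 2x2:
    #   SAME:  [1, 8 * 2, 8 * 2, 4 * 2]          -> [1, 16, 16, 8]
    #   VALID: [1, (8 - 1) * 2 + 3, ..., 4 * 2]  -> [1, 17, 17, 8]
    # i.e. the channel count is doubled relative to the input.
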
    def agPooling(op, shapes, rng):
        arg_list = []

        # Must be rank 4
        if len(shapes) < 4:
            return arg_list

        for stride_h in [1, 2]:
            for stride_w in [1, 2]:
                for kernel_h in [1, 2]:
                    for kernel_w in [1, 2]:
                        for padding in ["SAME", "VALID"]:

                            if (padding == "VALID") and (
                                (shapes[1] % (kernel_h * stride_h) > 0)
                                or (shapes[2] % (kernel_w * stride_w) > 0)
                                or (shapes[1] <= kernel_h)
                                or (shapes[2] <= kernel_w)
                            ):
                                continue

                            if (padding == "SAME") and (
                                (shapes[1] < kernel_h) or (shapes[2] < kernel_w)
                            ):
                                continue

                            if ((shapes[1] - kernel_h) % stride_h != 0) or (
                                (shapes[2] - kernel_w) % stride_w != 0
                            ):
                                # Not an exact integer output
                                continue

                            # Note: the tf.nn.avg_pool2d API doesn't support setting an
                            # accumulator type, so set a dummy value in the test name
                            # as a reminder
                            accum_dtype = ArgGen.typeStr(DType.INT32)
                            arg_list.append(
                                [
                                    "_st{}{}_pad{}_kern{}{}_acc{}".format(
                                        stride_h,
                                        stride_w,
                                        padding,
                                        kernel_h,
                                        kernel_w,
                                        accum_dtype,
                                    ),
                                    [
                                        [stride_h, stride_w],
                                        [kernel_h, kernel_w],
                                        padding,
                                    ],
                                ]
                            )
        return arg_list

    def getFactors(val, start=1):
        # Returns the factors of 'val' in [start, sqrt(val)); callers pick up
        # the matching co-factor by dividing the remaining element count.
        factors = []
        for i in range(start, int(np.sqrt(val))):
            if (val % i) == 0:
                factors.append(i)

        return factors

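    # For illustration, getFactors(36) -> [1, 2, 3, 4]: only divisors strictly
    # below sqrt(val) are returned (6, 9, 12, 18 and 36 are not), and callers
    # such as agReshape recover a co-factor by dividing it out of the
    # remaining element count.
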
    def agReshape(op, shapes, rng):
        # This is slow code. Fortunately, the numbers involved are small
        arg_list = []

        total_elements = 1
        for s in shapes:
            total_elements *= s

        # Find integer factors of this shape
        factors = ArgGen.getFactors(total_elements)

        for rank in range(1, len(shapes) + 1):
            if len(factors) < rank:
                break

            new_shape = []
            remaining_elements = total_elements

            # Randomly shuffle the factors and iteratively pick from the factors
            # of the remaining elements
            shuffled_factors = rng.permutation(factors)
            for i in range(rank):
                # Pick a factor of the remaining element count
                new_shape.append(shuffled_factors[0])
                remaining_elements = remaining_elements // shuffled_factors[0]
                shuffled_factors = rng.permutation(
                    ArgGen.getFactors(remaining_elements)
                )
            new_shape.append(remaining_elements)

            # Don't do no-op reshapes because TFLite optimizes out the op
            if new_shape == list(shapes):
                continue

            arg_list.append(["_rank{}".format(rank), [new_shape]])

        return arg_list

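    # For illustration, with shapes = (2, 3, 4) (24 elements) one run might
    # generate:
    #   ["_rank1", [[2, 12]]]
    #   ["_rank2", [[3, 4, 2]]]
    #   ["_rank3", [[2, 2, 3, 2]]]
    # Each candidate shape multiplies back to 24; note the generated shape has
    # rank + 1 dimensions (the picked factors plus the residual), and the
    # exact values depend on the rng permutations.
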
    def agTranspose(op, shapes, rng):
        arg_list = []

        # Must have at least two dimensions to transpose
        if (len(shapes)) < 2:
            return arg_list

        # Pick a bunch of random permutations
        range_arr = np.arange(len(shapes))
        for i in range(len(shapes)):
            perm = rng.permutation(range_arr).astype(np.int32)
            # print('\n shape {} permute{} perm: {} arr: {}'.format(shapes, i,
            # perm, range_arr))
            if np.allclose(perm, range_arr):
                print("skipped")
                continue
            arg_list.append(["_permute{}".format(i), [perm]])

        return arg_list

    def agSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random start points, axes, and strides
            start = np.empty((rank), dtype=int)
            size = np.empty((rank), dtype=int)
            for j in range(rank):
                if shapes[j] > 2:
                    start[j] = rng.integers(0, shapes[j] - 2)
                    # print('j = {}: {} - {} - 1: {}'.format(j, shapes[j],
                    # start[j], shapes[j] - start[j] - 1))
                    size[j] = rng.integers(1, shapes[j] - start[j] - 1)
                else:
                    start[j] = 0
                    size[j] = shapes[j]

            arg_list.append(["_perm{}".format(i), [start, size]])

        return arg_list

    def agStridedSlice(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Reference model is limited to rank=6 internally right now
        if rank > 3:
            return arg_list

        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random begin points, axes, and strides
            begin = np.empty((rank), dtype=int)
            end = np.empty((rank), dtype=int)
            strides = np.empty((rank), dtype=int)

            begin_mask = rng.integers(0, (1 << (rank - 1)))
            end_mask = rng.integers(0, (1 << (rank - 1)))

            for j in range(rank):

                if begin_mask & (1 << j) or shapes[j] < 2:
                    begin[j] = 0
                else:
                    begin[j] = rng.integers(0, shapes[j] - 1)

                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
                    end[j] = shapes[j]
                else:
                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)

                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)

                if not possible_stride:
                    strides[j] = 1
                else:
                    strides[j] = rng.choice(possible_stride)

            # Randomly set the masks, except ellipsis_mask and new_axis_mask,
            # which must be zero for now. For the begin/end masks to work,
            # strides must be adjusted to still be divisible...
            ellipsis_mask = 0
            new_axis_mask = 0

            # if rng.choice([0, 1]) and rank > 1:
            #     new_axis_mask = 1 << rng.integers(0, rank - 1)
            # else:
            #     new_axis_mask = 0

            if rng.choice([0, 1]) and rank > 1:
                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
            else:
                shrink_axis_mask = 0

            # Only one of these bits may be set. Prefer shrink_axis_mask
            new_axis_mask = new_axis_mask & ~shrink_axis_mask

            arg_list.append(
                [
                    "_perm{}".format(i),
                    [
                        begin,
                        end,
                        strides,
                        begin_mask,
                        end_mask,
                        ellipsis_mask,
                        new_axis_mask,
                        shrink_axis_mask,
                    ],
                ]
            )

            # print('Shape: {} begin={} end={} strides={} begin_mask={:x}
            # end_mask={:x} new_axis_mask={:x} shrink_mask={:x}'.format(shapes,
            # begin, end, strides, begin_mask, end_mask, new_axis_mask,
            # shrink_axis_mask))

        return arg_list

    # tf.stack axis can be in [0, rank(input)]
    def agStack(op, shapes, rng):
        axes = []
        for i in range(len(shapes) + 1):
            axes.append(["_axis{}".format(i), [i]])
        return axes

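    # For illustration, agStack(op, (2, 3), rng) returns
    #   [["_axis0", [0]], ["_axis1", [1]], ["_axis2", [2]]]
    # The axis may equal the input rank because stacking inserts a new
    # dimension.
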
    def agMirrorPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for mode in ["REFLECT", "SYMMETRIC"]:
            for left in range(3):
                for right in range(3):
                    paddings = np.zeros((rank, 2), dtype=np.int32)
                    is_valid = True

                    # Fill in the padding parameter if the values are valid on each
                    # dimension, otherwise drop that case.
                    for d in range(rank):
                        paddings[d, 0] = left
                        paddings[d, 1] = right

                        # In "REFLECT" mode, paddings must be no greater than
                        # tensor dim size - 1.
                        if mode == "REFLECT":
                            if (left > shapes[d] - 1) or (right > shapes[d] - 1):
                                is_valid = False
                                break

                        # In "SYMMETRIC" mode, paddings must be no greater than
                        # tensor dim size.
                        else:
                            if (left > shapes[d]) or (right > shapes[d]):
                                is_valid = False
                                break

                    if is_valid:
                        arg_list.append(
                            [
                                "_pad{}{}_{}".format(left, right, mode[0:3].lower()),
                                [paddings, mode],
                            ]
                        )
        return arg_list

    def agPad(op, shapes, rng):
        arg_list = []

        rank = len(shapes)
        for left in range(3):
            for right in range(3):
                # Padding nothing in TensorFlow Lite causes the interpreter to
                # fail to set the input tensor properly due to a data type
                # mismatch.
                if (left == 0) and (right == 0):
                    continue

                # A simple way to generate an explicit pad_const, including zero.
                pad_const = (left - right) * rng.integers(0, 5, dtype=np.int32)
                padding = np.zeros((rank, 2), dtype=np.int32)
                for d in range(rank):
                    padding[d, 0] = left
                    padding[d, 1] = right

                arg_list.append(
                    ["_pad{}{}".format(left, right), [padding, pad_const]]
                )
        return arg_list

    def agResize(op, shapes, rng):
        args = []
        for mode in ["nearest", "bilinear"]:
            for align_corners in [True, False]:
                for half_pixel in [True, False]:
                    # If half_pixel_centers is True, align_corners must be False.
                    if (align_corners is True) and (half_pixel is True):
                        continue

                    for i in range(1, 4):
                        args.append(
                            [
                                "_{}_align{}_half{}_scale{}".format(
                                    mode, int(align_corners), int(half_pixel), i
                                ),
                                [mode, align_corners, half_pixel, i],
                            ]
                        )
        return args

    def agFill(op, shapes, rng):
        values = []
        for i in range(4):
            value = rng.integers(0, 10, dtype=np.int32)
            values.append(["_value{}".format(value), [shapes, value]])
        return values

    def getValuesToSum(total, rng):
        # Get a list of random integers that sum up to 'total'
        vals = []

        # np.random.randint() requires min and max to be different, so if the
        # remainder is 1, give up
        while total > 1:
            vals.append(rng.integers(1, total))
            total = total - vals[-1]

        if total == 1:
            vals.append(1)

        return vals

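    # For illustration, getValuesToSum(7, rng) might return [3, 2, 1, 1] or
    # [6, 1]: a list of random positive integers that always sums to exactly 7.
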
    def agSplit(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # Shuffle the random number generator a few more times to get
        # a better range of axes across shapes
        for i in range(rank):
            for j in range(shapes[i]):
                rng.integers(shapes[i])

        for i in range(3):
            # Need to generate tests for both the num_splits and size_vector versions.
            axis = rng.choice(np.arange(0, rank))

            # For num_splits, get a few divisors of the given axis
            divs = ArgGen.getFactors(shapes[axis], 2)

            if divs:
                # Get no more than 2 samples
                splits = list(rng.choice(divs, size=2))

                for s in splits:
                    arg_list.append(
                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
                    )

            # For vector splits, get a list of integers that sum up to the axis size
            vals = ArgGen.getValuesToSum(shapes[axis], rng)

            if len(vals) > 1:
                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])

        return arg_list

    def agTile(op, shapes, rng):
        arg_list = []

        rank = len(shapes)

        # create 1D multiples list
        multiples = list()
        for i in range(rank):
            multiples.append(rng.integers(1, 4))

        multiples_str = "x".join(list(str(i) for i in multiples))

        arg_list.append(["_tile_{}".format(multiples_str), [multiples]])

        return arg_list

    def agGather(op, shapes, rng):
        args = []
        for batch_dims in range(len(shapes) - 1):
            for axis in range(batch_dims, len(shapes)):
                # indices values must be within [0, shapes[axis])

                # Create an arbitrary shape for the indices
                indices_rank = rng.integers(batch_dims + 1, 4)
                indices_shape = rng.integers(1, 8, size=indices_rank)

                # Copy in the batch dimensions because they must match
                for b in range(batch_dims):
                    indices_shape[b] = shapes[b]

                # Calculate total element count
                indices_size = 1
                for j in range(indices_rank):
                    indices_size = indices_shape[j] * indices_size

                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
                    indices_shape
                )

                args.append(
                    [
                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
                        [indices, batch_dims, axis],
                    ]
                )
        return args

    def agGatherND(op, shapes, rng):
        args = []

        for N in range(1, len(shapes) - 1):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            indices_count = 1
            for i in range(indices_rank - 1):
                indices_count = indices_count * indices_shape[i]

            indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32)

            for i in range(indices_count):
                for j in range(N):
                    indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0]

            indices = indices_list.reshape(indices_shape)

            args.append(["_n{}".format(N), [indices]])

        return args

    def agScatterND(op, shapes, rng):
        args = []

        # ScatterND has to generate a constant shapes tensor, indices
        # tensor, and a tensor of updates. Unfortunately, the updates
        # need to be a size that's based on the N generated in this
        # function and the dtype known only in the TensorGen function,
        # but not in ArgGen.
        #
        # There are many bad ways to solve this and we'll choose the
        # least of the evils which still gives reasonable coverage of
        # the possible operand shapes.
        for N in range(1, len(shapes)):
            # Rank includes the N dimension
            indices_rank = rng.integers(2, 4, size=1)[0]

            indices_shape = rng.integers(1, 8, size=indices_rank)
            indices_shape[-1] = N

            # Store the shapes, and the indices value tensor as arguments.
            args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]])

        return args

    def agSpaceToBatch(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []

        for i in range(block_rank):
            block_size = 2
            padding_size = block_size - (shapes[i + 1] % block_size)
            block_shape.append(block_size)
            padding_shape.append([0, padding_size])

        args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]])
        return args

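    # For illustration, with shapes = (1, 5, 7, 3) the block rank is 2 and the
    # generated argument is block_shape = [2, 2] with
    # padding_shape = [[0, 1], [0, 1]], padding each spatial dim up to a
    # multiple of the block size. Note that an already-divisible dimension
    # still receives a full block of padding, because
    # block_size - (dim % block_size) evaluates to block_size, not 0, there.
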
    def agBatchToSpace(op, shapes, rng):
        batch_rank = 1
        channel_rank = 1
        block_rank = len(shapes) - batch_rank - channel_rank

        # must have at least rank 1 (M) block
        if block_rank < 1:
            return []

        args = []
        block_shape = []
        padding_shape = []
        block_prod = 1

        for i in range(block_rank):
            block_size = 2
            block_prod = block_prod * block_size
            crop_size = 0
            block_shape.append(block_size)
            padding_shape.append([0, crop_size])

        # batch / prod(block_shape[i]) must be an integer.
        # The test transposes to swap depth and batch, so shape[-1] is the batch dim.
        if shapes[-1] % block_prod == 0:
            args.append(
                ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]
            )

        return args

    def agSpaceToDepth(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2

        # spatial dimension must be divisible by block_size
        if shapes[1] % block_size != 0 or shapes[2] % block_size != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    def agDepthToSpace(op, shapes, rng):
        # must be rank 4 input tensor
        if len(shapes) != 4:
            return []

        block_size = 2
        # depth dimension must be divisible by block_size * block_size
        if shapes[3] % (block_size * block_size) != 0:
            return []

        args = []
        args.append(["_blocksize_{}".format(block_size), [block_size]])

        return args

    def agFakequant(op, shapes, rng):
        args = []
        for num_bits in [8, 16]:
            for narrow in [False, True]:
                args.append(
                    ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]]
                )

        return args

    def agShift(op, shapes, rng):
        args = []

        for shift in rng.integers(0, 32, size=8):
            args.append(["_shift{}".format(shift), [shift]])

        return args

    def agFloat(op, shapes, rng):
        args = []

        # enumerate so each generated test gets a distinct "_{i}" name suffix
        # (a bare counter that is never incremented would name every test "_0")
        for i, alpha in enumerate(np.float32(rng.random(size=2))):
            args.append(["_{}".format(i), [alpha]])

        return args

    # Similar to agAxes, but tf.one_hot only allows axis in [-1, rank(input)]
    def agOneHot(op, shapes, rng):
        axes = []
        for i in range(-1, len(shapes) + 1, 1):
            if i >= 0:
                axes.append(["_axis_{}".format(i), [i]])
            else:
                axes.append(["_axis_m{}".format(-i), [i]])
        return axes

    def agRFFT2d(op, shape, rng):
        args = []

        # Must be rank 3 input tensor
        if len(shape) != 3:
            return []

        # Check rfft2d with enforced fft_length
        for fft_length_h in [2, 32]:
            for fft_length_w in [2, 8, 16]:
                fft_length = [fft_length_h, fft_length_w]
                args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])

        # Check rfft2d with no fft_length provided (fft_length=None).
        # In this case, the height and width of the input should be
        # used for the calculation. Therefore, we need to check that
        # the input shape is already a power of two.
        def is_power_of_two(x):
            return math.log(x, 2).is_integer()

        height, width = shape[1:3]
        if is_power_of_two(height) and is_power_of_two(width):
            args.append(["_fft_length_None", [None]])

        return args
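

# ---------------------------------------------------------------------------
# Illustrative usage, a sketch rather than part of the framework: each 'ag'
# method takes (op, shapes, rng) and returns a list of
# [test_name_suffix, argument_list] pairs. 'rng' is assumed to be a numpy
# Generator (the methods rely on its integers/choice/permutation API), and
# 'build_operator_test' below is a hypothetical consumer:
#
#   rng = np.random.default_rng(seed=42)
#   for suffix, extra_args in ArgGen.agAxes(None, (4, 5), rng):
#       build_operator_test("reduce_sum" + suffix, extra_args)
# ---------------------------------------------------------------------------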