# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
| 3 | import numpy as np |
| 4 | |
| 5 | |
class ArgGen:
    """Argument generator functions. These functions take a shape and dtype to
    create arguments for an operator. Methods are prefixed with 'ag' to make
    search easy."""

    def __init__(self):
        # No per-instance state: the ag* generators are invoked via the
        # class (ArgGen.agXxx(op, shapes, rng)), so nothing to initialize.
        pass

| 14 | @staticmethod |
| 15 | def agNone(op, shapes, rng): |
| 16 | """A trivial argument generator for operators that only take tensor |
| 17 | operands""" |
| 18 | return [("", [])] |
| 19 | |
| 20 | # Build the axis argument for operators where we want to iterate over N axes |
| 21 | # as an argument |
| 22 | @staticmethod |
| 23 | def agAxes(op, shapes, rng): |
| 24 | axes = [] |
| 25 | for i in range(-len(shapes), len(shapes), 1): |
| 26 | if i >= 0: |
| 27 | axes.append(["_axis_{}".format(i), [i]]) |
| 28 | else: |
| 29 | axes.append(["_axis_m{}".format(-i), [i]]) |
| 30 | return axes |
| 31 | |
| 32 | # Build the axis LIST argument for operators that take an axis list. |
| 33 | # This builds a list of each axis individually, plus one element |
| 34 | # that contains a list of all axes. Note that we need to pack the list in |
| 35 | # an additional list so that it isn't exploded when being passed to the |
| 36 | # build_operator function. |
| 37 | # tensor_arg_count not used |
| 38 | def agAxesList(op, shapes, rng): |
| 39 | axes = ArgGen.agAxes(op, shapes, rng) |
| 40 | axes_list = [] |
| 41 | for desc, a in axes: |
| 42 | axes_list.append([desc, [a]]) |
| 43 | |
| 44 | axes_list.append(["_axisall", [list(range(len(shapes)))]]) |
| 45 | axes_list.append(["_axisall_none", [None]]) |
| 46 | return axes_list |
| 47 | |
| 48 | def agAxesListKeepdims(op, shapes, rng): |
| 49 | axes = ArgGen.agAxes(op, shapes, rng) |
| 50 | axes_list = [] |
| 51 | for desc, a in axes: |
| 52 | axes_list.append([desc + "_keep0", [a, False]]) |
| 53 | # avoid trying to reduce an axis of shape 1, as the TFL converter |
| 54 | # will optimize away the entire reduction |
| 55 | if (a[0] >= 0 and shapes[a[0]] != 1) or ( |
| 56 | a[0] < 0 and shapes[len(shapes) + a[0]] != 1 |
| 57 | ): |
| 58 | axes_list.append([desc + "_keep1", [a, True]]) |
| 59 | |
| 60 | axes_list.append(["_axisall_keep0", [list(range(len(shapes))), False]]) |
| 61 | axes_list.append(["_axisall_keep0_none", [None, False]]) |
| 62 | # another instance where the reduce gets optimized out. |
| 63 | if len(shapes) != 1: |
| 64 | axes_list.append(["_axisall_keep1", [list(range(len(shapes))), True]]) |
| 65 | axes_list.append(["_axisall_keep1_none", [None, True]]) |
| 66 | # no longer test axis empty, as TFL converter optimizes the reduce out |
| 67 | return axes_list |
| 68 | |
| 69 | # conv2d argument generators build the TF constants |
| 70 | def agConv2d(op, shapes, rng): |
| 71 | arg_list = [] |
| 72 | |
| 73 | # Must be rank 4 |
| 74 | if len(shapes) < 4: |
| 75 | return arg_list |
| 76 | |
| 77 | filter_h, filter_w = op["filter"] |
| 78 | |
| 79 | # strides, padding, dilations, |
| 80 | for stride_h in [1, 2]: |
| 81 | for stride_w in [1, 2]: |
| 82 | for padding in ["SAME", "VALID"]: |
| 83 | for dilation_h in [1, 2]: |
| 84 | for dilation_w in [1, 2]: |
| 85 | |
| 86 | # Disqualify argument combinations that would cause |
| 87 | # an illegal convolution |
| 88 | |
| 89 | if (padding == "VALID") and ( |
| 90 | (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0 |
| 91 | or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0 |
| 92 | ): |
| 93 | continue |
| 94 | |
| 95 | arg_list.append( |
| 96 | [ |
| 97 | "_st{}{}_pad{}_dilat{}{}".format( |
| 98 | stride_h, |
| 99 | stride_w, |
| 100 | padding, |
| 101 | dilation_h, |
| 102 | dilation_w, |
| 103 | ), |
| 104 | [ |
| 105 | [stride_h, stride_w], |
| 106 | padding, |
| 107 | [dilation_h, dilation_w], |
| 108 | ], |
| 109 | ] |
| 110 | ) |
| 111 | return arg_list |
| 112 | |
| 113 | # conv2d argument generators build the TF constants |
| 114 | def agDepthwiseConv2d(op, shapes, rng): |
| 115 | arg_list = [] |
| 116 | |
| 117 | # Must be rank 4 |
| 118 | if len(shapes) < 4: |
| 119 | return arg_list |
| 120 | |
| 121 | filter_h, filter_w = op["filter"] |
| 122 | |
| 123 | # strides, padding, dilations, Depthwise conv2d is the same as conv2d |
| 124 | # except that strides in h/w must be the same and the argument must be |
| 125 | # formatted as [1, stride_h, stride_w, 1] in TF. |
| 126 | for stride in [1, 2]: |
| 127 | for padding in ["SAME", "VALID"]: |
| 128 | for dilation_h in [1, 2]: |
| 129 | for dilation_w in [1, 2]: |
| 130 | |
| 131 | # Disqualify argument combinations that would cause an illegal |
| 132 | # convolution |
| 133 | |
| 134 | if (padding == "VALID") and ( |
| 135 | (shapes[1] - (filter_h - 1) * 2 - dilation_h) <= 0 |
| 136 | or (shapes[2] - (filter_w - 1) * 2 - dilation_w) <= 0 |
| 137 | ): |
| 138 | continue |
| 139 | |
| 140 | # When dilation is used, stride must be 1x1 (TF rules) |
| 141 | if dilation_h > 1 or dilation_w > 1: |
| 142 | if stride > 1: |
| 143 | continue |
| 144 | |
| 145 | # Dilation must evenly divide the tensor. Some of our inputs |
| 146 | # intentionally use odd-sized tensors. |
| 147 | if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0: |
| 148 | continue |
| 149 | |
| 150 | arg_list.append( |
| 151 | [ |
| 152 | "_st{}{}_pad{}_dilat{}{}".format( |
| 153 | stride, stride, padding, dilation_h, dilation_w |
| 154 | ), |
| 155 | [ |
| 156 | [1, stride, stride, 1], |
| 157 | padding, |
| 158 | [dilation_h, dilation_w], |
| 159 | ], |
| 160 | ] |
| 161 | ) |
| 162 | return arg_list |
| 163 | |
| 164 | # conv2d argument generators build the TF constants |
| 165 | def agTransposeConv2d(op, shapes, rng): |
| 166 | arg_list = [] |
| 167 | |
| 168 | # Must be rank 4 |
| 169 | if len(shapes) < 4: |
| 170 | return arg_list |
| 171 | |
| 172 | filter_h, filter_w = op["filter"] |
| 173 | |
| 174 | # strides, padding, dilations, |
| 175 | for stride_h in [1, 2]: |
| 176 | for stride_w in [1, 2]: |
| 177 | for padding in ["SAME", "VALID"]: |
| 178 | if padding == "SAME": |
| 179 | out_height = (shapes[1]) * stride_h |
| 180 | out_width = (shapes[2]) * stride_w |
| 181 | else: # padding == 'VALID' |
| 182 | out_height = (shapes[1] - 1) * stride_h + filter_h |
| 183 | out_width = (shapes[2] - 1) * stride_w + filter_w |
| 184 | |
| 185 | output_shape = [shapes[0], out_height, out_width, shapes[3] * 2] |
| 186 | arg_list.append( |
| 187 | [ |
| 188 | "_st{}{}_pad{}".format(stride_h, stride_w, padding), |
| 189 | [output_shape, [stride_h, stride_w], padding], |
| 190 | ] |
| 191 | ) |
| 192 | return arg_list |
| 193 | |
| 194 | def agPooling(op, shapes, rng): |
| 195 | arg_list = [] |
| 196 | |
| 197 | # Must be rank 4 |
| 198 | if len(shapes) < 4: |
| 199 | return arg_list |
| 200 | |
| 201 | for stride_h in [1, 2]: |
| 202 | for stride_w in [1, 2]: |
| 203 | for kernel_h in [1, 2]: |
| 204 | for kernel_w in [1, 2]: |
| 205 | for padding in ["SAME", "VALID"]: |
| 206 | |
| 207 | if (padding == "VALID") and ( |
| 208 | (shapes[1] % (kernel_h * stride_h) > 0) |
| 209 | or (shapes[2] % (kernel_w * stride_w) > 0) |
| 210 | or (shapes[1] <= kernel_h) |
| 211 | or (shapes[2] <= kernel_w) |
| 212 | ): |
| 213 | continue |
| 214 | |
| 215 | if (padding == "SAME") and ( |
| 216 | (shapes[1] < kernel_h) or (shapes[2] < kernel_w) |
| 217 | ): |
| 218 | continue |
| 219 | |
| 220 | arg_list.append( |
| 221 | [ |
| 222 | "_st{}{}_pad{}_kern{}{}".format( |
| 223 | stride_h, stride_w, padding, kernel_h, kernel_w |
| 224 | ), |
| 225 | [ |
| 226 | [stride_h, stride_w], |
| 227 | [kernel_h, kernel_w], |
| 228 | padding, |
| 229 | ], |
| 230 | ] |
| 231 | ) |
| 232 | return arg_list |
| 233 | |
| 234 | def getFactors(val, start=1): |
| 235 | factors = [] |
| 236 | for i in range(start, int(np.sqrt(val))): |
| 237 | if (val % i) == 0: |
| 238 | factors.append(i) |
| 239 | |
| 240 | return factors |
| 241 | |
| 242 | def agReshape(op, shapes, rng): |
| 243 | # This is slow code. Fortunately, the numbers involved are small |
| 244 | arg_list = [] |
| 245 | |
| 246 | total_elements = 1 |
| 247 | for s in shapes: |
| 248 | total_elements *= s |
| 249 | |
| 250 | # Find integer factors of this shape |
| 251 | factors = ArgGen.getFactors(total_elements) |
| 252 | |
| 253 | for rank in range(1, len(shapes) + 1): |
| 254 | if len(factors) < rank: |
| 255 | break |
| 256 | |
| 257 | new_shape = [] |
| 258 | remaining_elements = total_elements |
| 259 | |
| 260 | # Randomly shuffle the factors and iteratively pick from the factors |
| 261 | # of the remaining elements |
| 262 | shuffled_factors = rng.permutation(factors) |
| 263 | for i in range(rank): |
| 264 | # Pick rank - 1 factors |
| 265 | new_shape.append(shuffled_factors[0]) |
| 266 | remaining_elements = remaining_elements // shuffled_factors[0] |
| 267 | shuffled_factors = rng.permutation( |
| 268 | ArgGen.getFactors(remaining_elements) |
| 269 | ) |
| 270 | new_shape.append(remaining_elements) |
| 271 | |
| 272 | # Don't do no-op reshapes because TFLite optimizes out the op |
| 273 | if new_shape == list(shapes): |
| 274 | continue |
| 275 | |
| 276 | arg_list.append(["_rank{}".format(rank), [new_shape]]) |
| 277 | |
| 278 | return arg_list |
| 279 | |
| 280 | def agTranspose(op, shapes, rng): |
| 281 | arg_list = [] |
| 282 | |
| 283 | # Must have at least two dimensions to transpose |
| 284 | if (len(shapes)) < 2: |
| 285 | return arg_list |
| 286 | |
| 287 | # Pick a bunch of random permutations |
| 288 | range_arr = np.arange(len(shapes)) |
| 289 | for i in range(len(shapes)): |
| 290 | perm = rng.permutation(range_arr).astype(np.int32) |
| 291 | # print('\n shape {} permute{} perm: {} arr: {}'.format(shapes, i, |
| 292 | # perm, range_arr)) |
| 293 | if np.allclose(perm, range_arr): |
| 294 | print("skipped") |
| 295 | continue |
| 296 | arg_list.append(["_permute{}".format(i), [perm]]) |
| 297 | |
| 298 | return arg_list |
| 299 | |
| 300 | def agSlice(op, shapes, rng): |
| 301 | arg_list = [] |
| 302 | |
| 303 | rank = len(shapes) |
| 304 | |
| 305 | if rank == 1 and shapes[0] == 1: |
| 306 | return arg_list |
| 307 | |
| 308 | for i in range(4): |
| 309 | # Pick a few random start points, axes, and strides |
| 310 | start = np.empty((rank), dtype=int) |
| 311 | size = np.empty((rank), dtype=int) |
| 312 | for j in range(rank): |
| 313 | if shapes[j] > 2: |
| 314 | start[j] = rng.integers(0, shapes[j] - 2) |
| 315 | # print('j = {}: {} - {} - 1: {}'.format(j, shapes[j], |
| 316 | # start[j], shapes[j] - start[j] - 1)) |
| 317 | size[j] = rng.integers(1, shapes[j] - start[j] - 1) |
| 318 | else: |
| 319 | start[j] = 0 |
| 320 | size[j] = shapes[j] |
| 321 | |
| 322 | arg_list.append(["_perm{}".format(i), [start, size]]) |
| 323 | |
| 324 | return arg_list |
| 325 | |
    def agStridedSlice(op, shapes, rng):
        """Generate random (begin, end, strides, masks) sets for strided_slice.

        Produces four randomized argument sets.  begin/end are chosen per
        dimension (honoring randomly drawn begin/end masks), strides are
        random factors of each dimension's span, and shrink_axis_mask is
        randomly set.  ellipsis_mask and new_axis_mask stay zero for now.
        """
        arg_list = []

        rank = len(shapes)

        # NOTE(review): comment said the reference model is limited to
        # rank=6 internally, but the code bails out above rank 3 — confirm
        # which limit is intended.
        if rank > 3:
            return arg_list

        # A single-element 1-D tensor has nothing interesting to slice
        if rank == 1 and shapes[0] == 1:
            return arg_list

        for i in range(4):
            # Pick a few random begin points, axes, and strides
            begin = np.empty((rank), dtype=int)
            end = np.empty((rank), dtype=int)
            strides = np.empty((rank), dtype=int)

            # A set bit forces the corresponding dimension to use its full
            # range below
            begin_mask = rng.integers(0, (1 << (rank - 1)))
            end_mask = rng.integers(0, (1 << (rank - 1)))

            for j in range(rank):

                # Masked (or tiny) dimensions start at 0
                if begin_mask & (1 << j) or shapes[j] < 2:
                    begin[j] = 0
                else:
                    begin[j] = rng.integers(0, shapes[j] - 1)

                # Masked, tiny, or nearly-exhausted dimensions end at the
                # full extent
                if end_mask & (1 << j) or shapes[j] < 2 or (begin[j] + 2) >= shapes[j]:
                    end[j] = shapes[j]
                else:
                    end[j] = rng.integers(begin[j] + 1, shapes[j] - 1)

                # Stride should divide the span evenly; fall back to 1
                possible_stride = ArgGen.getFactors(end[j] - begin[j], 2)

                if not possible_stride:
                    strides[j] = 1
                else:
                    strides[j] = rng.choice(possible_stride)

            # Randomly set the masks, except ellipsis_mask and new_axis_mask
            # which must be zero for now.  For begin/end masks to work,
            # strides must be adjusted to still be divisible...
            ellipsis_mask = 0
            new_axis_mask = 0

            # if rng.choice([0, 1]) and rank > 1:
            #     new_axis_mask = 1 << rng.integers(0, rank - 1)
            # else:
            #     new_axis_mask = 0

            if rng.choice([0, 1]) and rank > 1:
                shrink_axis_mask = 1 << rng.integers(0, rank - 1)
            else:
                shrink_axis_mask = 0

            # Only one of these bits may be set. Prefer shrink_axis_mask
            new_axis_mask = new_axis_mask & ~shrink_axis_mask

            arg_list.append(
                [
                    "_perm{}".format(i),
                    [
                        begin,
                        end,
                        strides,
                        begin_mask,
                        end_mask,
                        ellipsis_mask,
                        new_axis_mask,
                        shrink_axis_mask,
                    ],
                ]
            )

            # print('Shape: {} begin={} end={} strides={} begin_mask={:x}
            # end_mask={:x} new_axis_mask={:x} shrink_mask={:x}'.format(shapes,
            # begin, end, strides, begin_mask, end_mask, new_axis_mask,
            # shrink_axis_mask))

        return arg_list
| 407 | |
| 408 | # tf.stack axis can be [0, rank(input)] |
| 409 | def agStack(op, shapes, rng): |
| 410 | axes = [] |
| 411 | for i in range(len(shapes) + 1): |
| 412 | axes.append(["_axis{}".format(i), [i]]) |
| 413 | return axes |
| 414 | |
| 415 | def agPad(op, shapes, rng): |
| 416 | arg_list = [] |
| 417 | |
| 418 | rank = len(shapes) |
| 419 | for left in range(3): |
| 420 | for right in range(3): |
| 421 | paddings = np.zeros((rank, 2), dtype=np.int32) |
| 422 | for d in range(rank): |
| 423 | paddings[d, 0] = left |
| 424 | paddings[d, 1] = right |
| 425 | |
| 426 | arg_list.append(["_pad{}{}".format(left, right), [paddings]]) |
| 427 | return arg_list |
| 428 | |
| 429 | def agFill(op, shapes, rng): |
| 430 | values = [] |
| 431 | for i in range(4): |
| 432 | value = rng.integers(0, 10, dtype=np.int32) |
| 433 | values.append(["_value{}".format(value), [shapes, value]]) |
| 434 | return values |
| 435 | |
| 436 | def getValuesToSum(total, rng): |
| 437 | # Get a list of random integers that sum up to 'total' |
| 438 | vals = [] |
| 439 | |
| 440 | # np.random.randint() min and max to be different, so if the remainder |
| 441 | # is 1, give up |
| 442 | while total > 1: |
| 443 | vals.append(rng.integers(1, total)) |
| 444 | total = total - vals[-1] |
| 445 | |
| 446 | if total == 1: |
| 447 | vals.append(1) |
| 448 | |
| 449 | return vals |
| 450 | |
    def agSplit(op, shapes, rng):
        """Generate num_splits and size-vector arguments for split/split_v.

        For three randomly chosen axes: emit up to two 'num_splits' cases
        using divisors of the axis size, plus one 'size vector' case built
        from random integers summing to the axis size.
        """
        arg_list = []

        rank = len(shapes)

        # Shuffle the random number generator a few more times to get
        # a better range of axes across shapes
        for i in range(rank):
            for j in range(shapes[i]):
                rng.integers(shapes[i])

        for i in range(3):
            # Need to generate tests for both the num_splits and size_vector versions.
            axis = rng.choice(np.arange(0, rank))

            # For num_splits, get a few divisors of the given axis
            divs = ArgGen.getFactors(shapes[axis], 2)

            if divs:
                # Get no more than 2 samples
                splits = list(rng.choice(divs, size=2))

                for s in splits:
                    arg_list.append(
                        ["_split{}_axis{}".format(int(s), axis), [int(s), axis]]
                    )

            # For vector splits, get a list of integers that sum up to the axis size
            vals = ArgGen.getValuesToSum(shapes[axis], rng)

            if len(vals) > 1:
                arg_list.append(["_splitv_axis{}".format(axis), [vals, axis]])

        return arg_list
| 485 | |
| 486 | def agTile(op, shapes, rng): |
| 487 | arg_list = [] |
| 488 | |
| 489 | rank = len(shapes) |
| 490 | |
| 491 | # create 1D multiples list |
| 492 | multiples = list() |
| 493 | for i in range(rank): |
| 494 | multiples.append(rng.integers(1, 4)) |
| 495 | |
| 496 | multiples_str = "x".join(list(str(i) for i in multiples)) |
| 497 | |
| 498 | arg_list.append(["_tile_{}".format(multiples_str), [multiples]]) |
| 499 | |
| 500 | return arg_list |
| 501 | |
    def agGather(op, shapes, rng):
        """Generate (indices, batch_dims, axis) argument sets for gather.

        For each batch_dims/axis combination, build a random-shaped indices
        tensor whose leading batch dimensions match the input shape and
        whose values index validly into shapes[axis].
        """
        args = []
        for batch_dims in range(len(shapes) - 1):
            for axis in range(batch_dims, len(shapes)):
                # indices value must be within [0, shapes[i])

                # Create an arbitrary shape for the indices
                indices_rank = rng.integers(batch_dims + 1, 4)
                indices_shape = rng.integers(1, 8, size=indices_rank)

                # Copy in the batch dimensions because they must match
                for b in range(batch_dims):
                    indices_shape[b] = shapes[b]

                # Calculate total element count
                indices_size = 1
                for j in range(indices_rank):
                    indices_size = indices_shape[j] * indices_size

                # Random int32 indices in [0, shapes[axis]), reshaped to the
                # indices tensor shape
                indices = rng.integers(0, shapes[axis], indices_size, np.int32).reshape(
                    indices_shape
                )

                args.append(
                    [
                        "_batchdims_{}_axis_{}".format(batch_dims, axis),
                        [indices, batch_dims, axis],
                    ]
                )
        return args
| 532 | |
| 533 | def agGatherND(op, shapes, rng): |
| 534 | args = [] |
| 535 | |
| 536 | for N in range(1, len(shapes) - 1): |
| 537 | # Rank includes the N dimension |
| 538 | indices_rank = rng.integers(2, 4, size=1)[0] |
| 539 | indices_shape = [] |
| 540 | |
| 541 | indices_shape = rng.integers(1, 8, size=indices_rank) |
| 542 | indices_shape[-1] = N |
| 543 | |
| 544 | indices_count = 1 |
| 545 | for i in range(indices_rank - 1): |
| 546 | indices_count = indices_count * indices_shape[i] |
| 547 | |
| 548 | indices_list = np.zeros(shape=(indices_count, N), dtype=np.int32) |
| 549 | |
| 550 | for i in range(indices_count): |
| 551 | for j in range(N): |
| 552 | indices_list[i, j] = rng.integers(0, shapes[j], size=1)[0] |
| 553 | |
| 554 | indices = indices_list.reshape(indices_shape) |
| 555 | |
| 556 | args.append(["_n{}".format(N), [indices]]) |
| 557 | |
| 558 | return args |
| 559 | |
| 560 | def agScatterND(op, shapes, rng): |
| 561 | args = [] |
| 562 | |
| 563 | # ScatterND has to generate a constant shapes tensor, indices |
| 564 | # tensor, and a tensor of updates. Unforunately, the updates |
| 565 | # need to be a size that's based on the N generated in this |
| 566 | # function and the dtype known only in the TensorGen function, |
| 567 | # but not in ArgGen. |
| 568 | # |
| 569 | # There are many bad ways to solve this and we'll choose the |
| 570 | # least of the evils which still gives reasonable coverage of |
| 571 | # the possible operand shapes. |
| 572 | for N in range(1, len(shapes)): |
| 573 | # Rank includes the N dimension |
| 574 | indices_rank = rng.integers(2, 4, size=1)[0] |
| 575 | indices_shape = [] |
| 576 | |
| 577 | indices_shape = rng.integers(1, 8, size=indices_rank) |
| 578 | indices_shape[-1] = N |
| 579 | |
| 580 | # Store the Shapes, and the indicies value tensor as arguments. |
| 581 | args.append(["_n{}".format(N), [shapes, indices_shape, N, rng]]) |
| 582 | |
| 583 | return args |
| 584 | |
| 585 | def agSpaceToBatch(op, shapes, rng): |
| 586 | batch_rank = 1 |
| 587 | channel_rank = 1 |
| 588 | block_rank = len(shapes) - batch_rank - channel_rank |
| 589 | |
| 590 | # must have at least rank 1 (M) block |
| 591 | if block_rank < 1: |
| 592 | return [] |
| 593 | |
| 594 | args = [] |
| 595 | block_shape = [] |
| 596 | padding_shape = [] |
| 597 | |
| 598 | for i in range(block_rank): |
| 599 | block_size = 2 |
| 600 | padding_size = block_size - (shapes[i + 1] % block_size) |
| 601 | block_shape.append(block_size) |
| 602 | padding_shape.append([0, padding_size]) |
| 603 | |
| 604 | args.append(["_blockrank_{}".format(block_rank), [block_shape, padding_shape]]) |
| 605 | return args |
| 606 | |
| 607 | def agBatchToSpace(op, shapes, rng): |
| 608 | batch_rank = 1 |
| 609 | channel_rank = 1 |
| 610 | block_rank = len(shapes) - batch_rank - channel_rank |
| 611 | |
| 612 | # must have at least rank 1 (M) block |
| 613 | if block_rank < 1: |
| 614 | return [] |
| 615 | |
| 616 | args = [] |
| 617 | block_shape = [] |
| 618 | padding_shape = [] |
| 619 | block_prod = 1 |
| 620 | |
| 621 | for i in range(block_rank): |
| 622 | block_size = 2 |
| 623 | block_prod = block_prod * block_size |
| 624 | crop_size = 0 |
| 625 | block_shape.append(block_size) |
| 626 | padding_shape.append([0, crop_size]) |
| 627 | |
| 628 | # batch / prod(block_shape[i]) must be integer |
| 629 | # transpose to swap depth and batch. so shape[-1] would be batch dim |
| 630 | if shapes[-1] % block_prod == 0: |
| 631 | args.append( |
| 632 | ["_blockrank_{}".format(block_rank), [block_shape, padding_shape]] |
| 633 | ) |
| 634 | |
| 635 | return args |
| 636 | |
| 637 | def agSpaceToDepth(op, shapes, rng): |
| 638 | # must be rank 4 input tensor |
| 639 | if len(shapes) != 4: |
| 640 | return [] |
| 641 | |
| 642 | block_size = 2 |
| 643 | |
| 644 | # spatial dimension must be divisible by block_size |
| 645 | if shapes[1] % block_size != 0 or shapes[2] % block_size != 0: |
| 646 | return [] |
| 647 | |
| 648 | args = [] |
| 649 | args.append(["_blocksize_{}".format(block_size), [block_size]]) |
| 650 | |
| 651 | return args |
| 652 | |
| 653 | def agDepthToSpace(op, shapes, rng): |
| 654 | # must be rank 4 input tensor |
| 655 | if len(shapes) != 4: |
| 656 | return [] |
| 657 | |
| 658 | block_size = 2 |
| 659 | # depth dimension must be divisible by block_size * block_size |
| 660 | if shapes[3] % (block_size * block_size) != 0: |
| 661 | return [] |
| 662 | |
| 663 | args = [] |
| 664 | args.append(["_blocksize_{}".format(block_size), [block_size]]) |
| 665 | |
| 666 | return args |
| 667 | |
| 668 | def agFakequant(op, shapes, rng): |
| 669 | args = [] |
| 670 | for num_bits in [8, 16]: |
| 671 | for narrow in [False, True]: |
| 672 | args.append( |
| 673 | ["_bits{}_narrow{}".format(num_bits, narrow), [num_bits, narrow]] |
| 674 | ) |
| 675 | |
| 676 | return args |
| 677 | |
| 678 | def agShift(op, shapes, rng): |
| 679 | args = [] |
| 680 | |
| 681 | for shift in rng.integers(0, 32, size=8): |
| 682 | args.append(["_shift{}".format(shift), [shift]]) |
| 683 | |
| 684 | return args |
| 685 | |
| 686 | def agFloat(op, shapes, rng): |
| 687 | args = [] |
| 688 | |
| 689 | i = 0 |
| 690 | for alpha in np.float32(rng.random(size=2)): |
| 691 | args.append(["_{}".format(i), [alpha]]) |
| 692 | |
| 693 | return args |
| 694 | |
| 695 | # Similar to agAxes, but tf.OneHot only allow axis from [-1, rank(input)] |
| 696 | def agOneHot(op, shapes, rng): |
| 697 | axes = [] |
| 698 | for i in range(-1, len(shapes) + 1, 1): |
| 699 | if i >= 0: |
| 700 | axes.append(["_axis_{}".format(i), [i]]) |
| 701 | else: |
| 702 | axes.append(["_axis_m{}".format(-i), [i]]) |
| 703 | return axes |