# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
from frameworks.tensor_gen import TGen


class TBuilder:
    """The member functions build the tensorflow operators into small networks
    for our tests"""
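
    # Typical use in the test harness (a sketch only; the real driver code
    # lives elsewhere in this framework): each member class is constructed
    # with its operator arguments plus a result tensor name, and eval()
    # builds the small network, e.g.:
    #   op = TBuilder.Add("result")
    #   result = op.eval(tf.constant([1.0]), tf.constant([2.0]))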

    def __init__(self):
        pass

    @staticmethod
    def fake_quant(tensor, tensor_scale, name):
        """Helper function for quantizing with a scaling parameters structure."""
        return tf.quantization.fake_quant_with_min_max_args(
            tensor,
            min=tensor_scale.min,
            max=tensor_scale.max,
            num_bits=tensor_scale.num_bits,
            narrow_range=tensor_scale.narrow_range,
            name=name,
        )

    @staticmethod
    def fake_quant_params(tensor, min, max, scaling, name):
        """Helper function for quantizing with individual scaling parameters."""
        return tf.quantization.fake_quant_with_min_max_args(
            tensor,
            min=min,
            max=max,
            num_bits=scaling.num_bits,
            narrow_range=scaling.narrow_range,
            name=name,
        )
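
    # A minimal usage sketch for the helpers above (assumption: the scaling
    # object exposes min, max, num_bits and narrow_range, mirroring the
    # attributes read here; the real object comes from the test framework):
    #   from types import SimpleNamespace
    #   scale = SimpleNamespace(min=-2.0, max=2.0, num_bits=8, narrow_range=False)
    #   q = TBuilder.fake_quant(tf.constant([0.5]), scale, "fq")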

    class Add:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.add(a, b, name=self.result_name)

    class Sub:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.subtract(a, b, name=self.result_name)

    class Mul:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.multiply(a, b, name=self.result_name)

    class Exp:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.exp(a, name=self.result_name)

    class Rcp:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.reciprocal(a, name=self.result_name)

    class Relu:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.relu(a, name=self.result_name)

    class Relu1:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            # TF doesn't have a relu_n1_to_1 operator, so use min and max as a
            # workaround; alternatively, tf.clip_by_value could be used.
            return tf.math.minimum(1.0, tf.math.maximum(-1.0, a), name=self.result_name)

    class Relu6:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.relu6(a, name=self.result_name)

    class LeakyRelu:
        def __init__(self, alpha, name):
            self.alpha = alpha
            self.result_name = name

        def eval(self, a):
            return tf.nn.leaky_relu(a, alpha=self.alpha, name=self.result_name)

    class Gelu:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.gelu(a, name=self.result_name)

    class Concat:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b):
            return tf.concat([a, b], self.axis, name=self.result_name)

    class BitwiseAnd:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_and(a, b, name=self.result_name)

    class BitwiseOr:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_or(a, b, name=self.result_name)

    class BitwiseNot:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.invert(a, name=self.result_name)

    class BitwiseXor:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_xor(a, b, name=self.result_name)

    class LogicalAnd:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.logical_and(a, b, name=self.result_name)

    class LogicalOr:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.logical_or(a, b, name=self.result_name)

    class LogicalNot:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.logical_not(a, name=self.result_name)

    class ReduceAny:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_any(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceAll:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_all(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMin:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_min(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMax:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_max(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceSum:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_sum(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMean:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_mean(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceProduct:
        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_prod(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class Min:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.minimum(a, b, name=self.result_name)

    class Max:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.maximum(a, b, name=self.result_name)

    class Pow:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.pow(a, b, name=self.result_name)

    class Abs:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.abs(a, name=self.result_name)

    class Ceil:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.ceil(a, name=self.result_name)

    class Floor:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.floor(a, name=self.result_name)

    class Log:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.log(a, name=self.result_name)

    class Negate:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.negative(a, name=self.result_name)

    class Rsqrt:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.rsqrt(a, name=self.result_name)

    class Sigmoid:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sigmoid(a, name=self.result_name)

    class Tanh:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.tanh(a, name=self.result_name)

    class Square:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.square(a, name=self.result_name)

    class SquaredDifference:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.squared_difference(a, b, name=self.result_name)

    class Equal:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.equal(a, b, name=self.result_name)

    class GreaterEqual:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.greater_equal(a, b, name=self.result_name)

    class Greater:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.greater(a, b, name=self.result_name)

    class Less:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.less(a, b, name=self.result_name)

    class LessEqual:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.less_equal(a, b, name=self.result_name)

    class Conv2d:
        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name=self.result_name,
            )

    class Conv2dRelu:
        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.nn.relu(conv2d, name=self.result_name)

    class Conv2dRelu6:
        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.nn.relu6(conv2d, name=self.result_name)

    class Conv2dReluN1To1:
        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.clip_by_value(conv2d, -1.0, 1.0, name=self.result_name)

    class Conv2dTanh:
        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.math.tanh(conv2d, name=self.result_name)

    class Conv2dWithBias:
        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            conv2d_op = tf.nn.conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="conv2d",
            )
            bias_add_op = tf.nn.bias_add(
                conv2d_op, self.bias, data_format="NHWC", name=self.result_name
            )
            return bias_add_op

    class Conv3d:
        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv3d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NDHWC",
                dilations=self.dilations,
                name=self.result_name,
            )

    class Conv3dWithBias:
        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            conv3d_op = tf.nn.conv3d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NDHWC",
                dilations=self.dilations,
                name="conv3d",
            )
            bias_add_op = tf.nn.bias_add(conv3d_op, self.bias, name=self.result_name)
            return bias_add_op

    class DepthwiseConv2d:
        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            dws_conv2d = tf.nn.depthwise_conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="dws_conv2d",
            )
            return tf.identity(dws_conv2d, name=self.result_name)

    class DepthwiseConv2dWithBias:
        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            dws_conv2d = tf.nn.depthwise_conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="dws_conv2d",
            )
            bias_add_op = tf.nn.bias_add(
                dws_conv2d, self.bias, data_format="NHWC", name=self.result_name
            )
            return bias_add_op

    class TransposeConv2d:
        def __init__(self, weight, output_shape, strides, padding, name):
            self.weight = weight
            self.output_shape = output_shape
            self.strides = strides
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv2d_transpose(
                input,
                self.weight,
                self.output_shape,
                self.strides,
                self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class Argmax:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.argmax(a, self.axis, output_type=tf.int32, name=self.result_name)

    class AvgPool2d:
        def __init__(self, strides, kernel_size, padding, name):
            self.strides = strides
            self.kernel_size = kernel_size
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.avg_pool2d(
                input,
                strides=self.strides,
                ksize=self.kernel_size,
                padding=self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class MaxPool2d:
        def __init__(self, strides, kernel_size, padding, name):
            self.strides = strides
            self.kernel_size = kernel_size
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.max_pool2d(
                input,
                strides=self.strides,
                ksize=self.kernel_size,
                padding=self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class Reshape:
        def __init__(self, shape, name):
            self.shape = shape
            self.result_name = name

        def eval(self, a):
            reshape_op = tf.reshape(a, self.shape)
            return tf.identity(reshape_op, name=self.result_name)

    class Transpose:
        def __init__(self, perm, name):
            self.perm = perm
            self.result_name = name

        def eval(self, a):
            return tf.transpose(a, self.perm, name=self.result_name)

    class Slice:
        def __init__(self, begin, size, name):
            self.begin = begin
            self.size = size
            self.result_name = name

        def eval(self, a):
            return tf.slice(a, begin=self.begin, size=self.size, name=self.result_name)

    class StridedSlice:
        def __init__(
            self,
            begin,
            end,
            strides,
            begin_mask,
            end_mask,
            ellipsis_mask,
            new_axis_mask,
            shrink_axis_mask,
            name,
        ):
            self.begin = begin
            self.end = end
            self.strides = strides
            self.begin_mask = begin_mask
            self.end_mask = end_mask
            self.ellipsis_mask = ellipsis_mask
            self.new_axis_mask = new_axis_mask
            self.shrink_axis_mask = shrink_axis_mask
            self.result_name = name

        def eval(self, a):
            return tf.strided_slice(
                a,
                begin=self.begin,
                end=self.end,
                strides=self.strides,
                begin_mask=self.begin_mask,
                end_mask=self.end_mask,
                ellipsis_mask=self.ellipsis_mask,
                new_axis_mask=self.new_axis_mask,
                shrink_axis_mask=self.shrink_axis_mask,
                name=self.result_name,
            )

    class Select:
        def __init__(self, name):
            self.result_name = name

        def eval(self, selector, a, b):
            return tf.where(condition=selector, x=a, y=b, name=self.result_name)

    class Addn:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.add_n([a, b, c, d], name=self.result_name)

    class Concatv2:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)

    class Stack:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.stack([a, b, c, d], axis=self.axis, name=self.result_name)

    class Unstack:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            unstack_op = tf.unstack(a, axis=self.axis, name="unstack_op")
            result_count = a.shape[self.axis]

            if result_count == 1:
                return tf.identity(unstack_op[0], name=self.result_name)

            sums = []
            for i in range(result_count):
                sums.append(
                    tf.math.reduce_sum(unstack_op[i], name="reduce_{}".format(i))
                )
            return tf.stack(sums, 0, name=self.result_name)

    class MirrorPad:
        def __init__(self, padding, mode, name):
            self.padding = padding
            self.mode = mode
            self.result_name = name

        def eval(self, a):
            return tf.pad(
                a,
                self.padding,
                mode=self.mode,
                constant_values=0,
                name=self.result_name,
            )

    class Pad:
        def __init__(self, padding, name):
            self.padding = padding
            self.result_name = name

        def eval(self, a):
            return tf.pad(
                a,
                self.padding,
                mode="CONSTANT",
                constant_values=0,
                name=self.result_name,
            )

    class ExpandDims:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.expand_dims(a, self.axis, name=self.result_name)

    class Shape:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.shape(a, name=self.result_name)

    class Rank:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.rank(a, name=self.result_name)

    class Fill:
        def __init__(self, shape, value, name):
            self.shape = shape
            self.value = value
            self.result_name = name

        def eval(self, a):
            return tf.fill(self.shape, self.value, name=self.result_name)

    class Elu:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.elu(a, name=self.result_name)

    class Softmax:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.softmax(a, name=self.result_name)

    class LogSoftmax:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.log_softmax(a, name=self.result_name)

    class MatMul:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.linalg.matmul(a, b, name=self.result_name)

    class AddScalar:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.add(a, 1, name=self.result_name)

    class Add1d:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            if len(b.shape) > 1:
                b_1d = tf.reduce_sum(b, axis=list(range(0, len(b.shape) - 1, 1)))
            else:
                b_1d = b
            return tf.add(a, b_1d, name=self.result_name)

    class Split:
        def __init__(self, num_splits, axis, name):
            self.num_splits = num_splits
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            # The split op generates a list of outputs. Since we have difficulty
            # serializing a list or array of Numpy arrays, we reduce each of the
            # results to a single sum and stack them into one output tensor.

            if not isinstance(self.num_splits, list):
                split_op = tf.split(
                    a, num_or_size_splits=self.num_splits, axis=self.axis, name="split"
                )
                result_count = self.num_splits
            else:
                num_split = np.asarray(self.num_splits, dtype=np.int32)
                split_vec_op = tf.compat.v1.constant(
                    num_split,
                    shape=num_split.shape,
                    dtype=tf.int32,
                    name="const_split_vec",
                )
                split_op = tf.split(
                    a, num_or_size_splits=split_vec_op, axis=self.axis, name="split"
                )
                result_count = num_split.shape[0]

            sums = []
            for i in range(result_count):
                sums.append(tf.math.reduce_sum(split_op[i], name="reduce_{}".format(i)))
            return tf.stack(sums, 0, name=self.result_name)

    class Tile:
        def __init__(self, multiples, name):
            self.multiples = multiples
            self.result_name = name

        def eval(self, a):
            t = tf.tile(a, self.multiples, name="tile")
            return tf.identity(t, name=self.result_name)

    class Reverse:
        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.reverse(a, [self.axis], name=self.result_name)

    class Gather:
        def __init__(self, indices, batch_dims, axis, name):
            self.indices = indices
            self.batch_dims = batch_dims
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.gather(
                a,
                self.indices,
                batch_dims=self.batch_dims,
                axis=self.axis,
                name=self.result_name,
            )

    class GatherNd:
        def __init__(self, indices, name):
            self.indices = indices
            self.result_name = name

        def eval(self, a):
            return tf.gather_nd(a, self.indices, name=self.result_name)

    class ScatterNd:
        def __init__(self, shape, indices_shape, N, rng, name):
            self.shape = shape
            self.indices_shape = indices_shape
            self.N = N
            self.rng = rng
            self.result_name = name

        def eval(self, a):

            # This operator is special. The indices and updates tensors really need
            # to be created together, but in the current structure of this tool there
            # is no way to do that before now. The number of updates is determined by
            # the indices, so we can really only create the updates after the indices,
            # but we don't know the dtype at that time.
            #
            # Shapes are guaranteed deterministic, but we'll use our rng
            # copied from the arggen stage. It's possible that index and
            # update *values* will be non-deterministic.
            #
            # We take the input tensor 'a' simply to get the dtype.

            shape_const = tf.constant(self.shape, tf.int32)

            updates_shape = list(self.indices_shape[:-1])
            updates_shape.extend(self.shape[self.indices_shape[-1] :])

            updates_const = tf.constant(TGen.getRand(updates_shape, a.dtype, self.rng))

            indices = np.zeros(self.indices_shape, dtype=np.int32)

            # We need to generate the random indices tensor based on the
            # limits of 'shape' for each dimension. Surely, there is a faster
            # vectorized way to do this, but the tensors are fairly small so we
            # will do this one element at a time. Each element needs to be sized
            # based on the size of the last dimension.
            for idx in np.ndindex(indices.shape):
                indices[idx] = self.rng.integers(0, self.shape[idx[-1]], size=1)[0]
                # print('{} {}'.format(idx, indices[idx]))

            indices_const = tf.constant(indices, dtype=tf.int32)

            return tf.scatter_nd(
                indices=indices_const,
                updates=updates_const,
                shape=shape_const,
                name=self.result_name,
            )

    class SpaceToBatch:
        def __init__(self, block_shape, padding, name):
            self.block_shape = block_shape
            self.padding = padding
            self.result_name = name

        def eval(self, a):
            return tf.space_to_batch(
                a, self.block_shape, self.padding, name=self.result_name
            )

    class BatchToSpace:
        def __init__(self, block_shape, cropping, name):
            self.block_shape = block_shape
            self.cropping = cropping
            self.result_name = name

        def eval(self, a):
            # Transpose to swap the depth and batch dimensions first; this
            # avoids having to add a new shape around the batch_to_space op.
            block_rank = len(self.block_shape)
            perm = [len(a.shape) - 1]
            for i in range(block_rank):
                perm.append(i + 1)
            perm.append(0)
            transpose_op = tf.transpose(a, perm)
            return tf.batch_to_space(
                transpose_op, self.block_shape, self.cropping, name=self.result_name
            )

    class SpaceToDepth:
        def __init__(self, block_shape, name):
            self.block_shape = block_shape
            self.result_name = name

        def eval(self, a):
            return tf.nn.space_to_depth(a, self.block_shape, name=self.result_name)

    class DepthToSpace:
        def __init__(self, block_shape, name):
            self.block_shape = block_shape
            self.result_name = name

        def eval(self, a):
            return tf.nn.depth_to_space(a, self.block_shape, name=self.result_name)

    class OneHot:
        def __init__(self, depth, axis, name):
            self.depth = depth
            self.axis = axis
            self.result_name = name

        def eval(self, indices, on_value, off_value):
            return tf.one_hot(
                indices,
                self.depth,
                on_value,
                off_value,
                self.axis,
                on_value.dtype,
                self.result_name,
            )

    class Fakequant:
        def __init__(self, num_bits, narrow_range, name):
            self.num_bits = num_bits
            self.narrow_range = narrow_range
            self.result_name = name

        def eval(self, a):
            return tf.quantization.fake_quant_with_min_max_args(
                a,
                min=-2.0,
                max=2.0,
                num_bits=self.num_bits,
                narrow_range=self.narrow_range,
                name=self.result_name,
            )

    class ResizeNearest:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            out_shape = []
            out_shape.append(a.shape[1] * 2)
            out_shape.append(a.shape[2] * 2)

            # tf.image.resize() will overwrite the node name by appending the
            # resize method to result_name, so an extra identity is needed to
            # force the output tensor name to result_name instead of calling
            # tf.image.resize(a, out_shape, ..., name=result_name) directly.
            resize = tf.image.resize(
                a,
                out_shape,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                name="resize",
            )
            return tf.identity(resize, name=self.result_name)

    class ResizeBilinear:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            out_shape = []
            out_shape.append(a.shape[1] * 2)
            out_shape.append(a.shape[2] * 2)

            # tf.image.resize() will overwrite the node name with result_name +
            # '/BILINEAR', so an extra identity is needed to force the output
            # tensor name to result_name instead of calling
            # tf.image.resize(a, out_shape, ..., name=result_name) directly.
            resize = tf.image.resize(
                a, out_shape, method=tf.image.ResizeMethod.BILINEAR, name="resize"
            )
            return tf.identity(resize, name=self.result_name)

    # The new tf resize op sets (align_corners, half_pixel_centers) = (false, true)
    # by default. Test the remaining option combinations here.
    # Note that (align_corners, half_pixel_centers) = (true, true) is NOT valid.
    class ResizeBilinearV1AlignCorners:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            out_shape = []
            out_shape.append(a.shape[1] * 2)
            out_shape.append(a.shape[2] * 2)

            resize = tf.compat.v1.image.resize_bilinear(
                a,
                out_shape,
                align_corners=True,
                name="resize",
                half_pixel_centers=False,
            )
            return tf.identity(resize, name=self.result_name)

    class ResizeBilinearV1None:
        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            out_shape = []
            out_shape.append(a.shape[1] * 2)
            out_shape.append(a.shape[2] * 2)

            resize = tf.compat.v1.image.resize_bilinear(
                a,
                out_shape,
                align_corners=False,
                name="resize",
                half_pixel_centers=False,
            )
            return tf.identity(resize, name=self.result_name)

    class LeftShift:
        def __init__(self, shift, name):
            self.shift = shift
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.left_shift(a, self.shift, name=self.result_name)

    class RightShift:
        def __init__(self, shift, name):
            self.shift = shift
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.right_shift(a, self.shift, name=self.result_name)