blob: c7ba9a9f6f6a3a1160f31bda431cdc591400b605 [file] [log] [blame]
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
3import numpy as np
4import tensorflow as tf
5from frameworks.tensor_gen import TGen
6
7
class TBuilder:
    """The member functions build the tensorflow operators into small networks
    for our tests.

    Each nested class wraps one operator (or a short chain of operators). The
    constructor stores the operator's static parameters and the requested
    result name in ``result_name``; ``eval`` applies the operator to the input
    tensor(s) passed to it.
    """

    def __init__(self):
        pass

    @staticmethod
    def fake_quant(tensor, tensor_scale, name):
        """Helper function for quantizing with a scaling parameters structure.

        ``tensor_scale`` must provide ``min``, ``max``, ``num_bits`` and
        ``narrow_range`` attributes.
        """
        return tf.quantization.fake_quant_with_min_max_args(
            tensor,
            min=tensor_scale.min,
            max=tensor_scale.max,
            num_bits=tensor_scale.num_bits,
            narrow_range=tensor_scale.narrow_range,
            name=name,
        )

    @staticmethod
    def fake_quant_params(tensor, min, max, scaling, name):
        """Helper function for quantizing with individual scaling parameters.

        ``min``/``max`` give the quantization range explicitly; ``scaling``
        supplies ``num_bits`` and ``narrow_range``. The parameter names shadow
        the builtins but are kept so keyword-argument callers keep working.
        """
        return tf.quantization.fake_quant_with_min_max_args(
            tensor,
            min=min,
            max=max,
            num_bits=scaling.num_bits,
            narrow_range=scaling.narrow_range,
            name=name,
        )

    @staticmethod
    def _conv2d_same(tensor, weight):
        """Shared NHWC conv2d (unit stride/dilation, SAME padding) named "conv2d"."""
        return tf.nn.conv2d(
            tensor,
            weight,
            [1, 1, 1, 1],
            "SAME",
            data_format="NHWC",
            dilations=[1, 1, 1, 1],
            name="conv2d",
        )

    class Add:
        """Element-wise a + b (tf.add)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.add(a, b, name=self.result_name)

    class Sub:
        """Element-wise a - b (tf.subtract)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.subtract(a, b, name=self.result_name)

    class Mul:
        """Element-wise a * b (tf.multiply)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.multiply(a, b, name=self.result_name)

    class Exp:
        """Element-wise exponential (tf.exp)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.exp(a, name=self.result_name)

    class Rcp:
        """Element-wise reciprocal (tf.math.reciprocal)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.reciprocal(a, name=self.result_name)

    class Relu:
        """ReLU activation (tf.nn.relu)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.relu(a, name=self.result_name)

    class Relu1:
        """Clamps values to [-1, 1] (relu_n1_to_1)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            # TF doesn't have a relu_n1_to_1 operator, so use min and max as a
            # workaround (tf.clip_by_value would be an alternative). The outer
            # op carries result_name, like every other builder here.
            return tf.math.minimum(
                1.0, tf.math.maximum(-1.0, a), name=self.result_name
            )

    class Relu6:
        """ReLU6 activation (tf.nn.relu6)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.relu6(a, name=self.result_name)

    class LeakyRelu:
        """Leaky ReLU with slope ``alpha`` for negative inputs."""

        def __init__(self, alpha, name):
            self.alpha = alpha
            self.result_name = name

        def eval(self, a):
            return tf.nn.leaky_relu(a, alpha=self.alpha, name=self.result_name)

    class Prelu:
        """Keras PReLU layer with alpha drawn from a N(0, 1) initializer.

        ``result_name`` is stored but not applied to the layer output.
        """

        def __init__(self, name):
            self.result_name = name
            self.prelu = tf.keras.layers.PReLU(
                alpha_initializer=tf.keras.initializers.RandomNormal(
                    mean=0.0, stddev=1.0
                )
            )

        def eval(self, a):
            return self.prelu(a)

    class Gelu:
        """GELU activation (tf.nn.gelu)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.gelu(a, name=self.result_name)

    class Concat:
        """Concatenates two tensors along ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b):
            return tf.concat([a, b], self.axis, name=self.result_name)

    class BitwiseAnd:
        """Element-wise bitwise AND."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_and(a, b, name=self.result_name)

    class BitwiseOr:
        """Element-wise bitwise OR."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_or(a, b, name=self.result_name)

    class BitwiseNot:
        """Element-wise bitwise inversion."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.invert(a, name=self.result_name)

    class BitwiseXor:
        """Element-wise bitwise XOR."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_xor(a, b, name=self.result_name)

    class LogicalAnd:
        """Element-wise logical AND."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.logical_and(a, b, name=self.result_name)

    class LogicalOr:
        """Element-wise logical OR."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.logical_or(a, b, name=self.result_name)

    class LogicalNot:
        """Element-wise logical NOT."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.logical_not(a, name=self.result_name)

    class ReduceAny:
        """Logical-OR reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_any(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceAll:
        """Logical-AND reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_all(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMin:
        """Minimum reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_min(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMax:
        """Maximum reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_max(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceSum:
        """Sum reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_sum(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMean:
        """Mean reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_mean(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceProduct:
        """Product reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_prod(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class Min:
        """Element-wise minimum."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.minimum(a, b, name=self.result_name)

    class Max:
        """Element-wise maximum."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.maximum(a, b, name=self.result_name)

    class Pow:
        """Element-wise a ** b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.pow(a, b, name=self.result_name)

    class Abs:
        """Element-wise absolute value."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.abs(a, name=self.result_name)

    class Ceil:
        """Element-wise ceiling."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.ceil(a, name=self.result_name)

    class Floor:
        """Element-wise floor."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.floor(a, name=self.result_name)

    class Log:
        """Element-wise natural logarithm."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.log(a, name=self.result_name)

    class Negate:
        """Element-wise negation."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.negative(a, name=self.result_name)

    class Rsqrt:
        """Element-wise reciprocal square root."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.rsqrt(a, name=self.result_name)

    class Sign:
        """Element-wise sign."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sign(a, name=self.result_name)

    class Sigmoid:
        """Element-wise sigmoid."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sigmoid(a, name=self.result_name)

    class Tanh:
        """Element-wise hyperbolic tangent."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.tanh(a, name=self.result_name)

    class Sin:
        """Element-wise sine."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sin(a, name=self.result_name)

    class Cos:
        """Element-wise cosine."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.cos(a, name=self.result_name)

    class Atan2:
        """Element-wise atan2(a, b)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.atan2(a, b, name=self.result_name)

    class Square:
        """Element-wise square."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.square(a, name=self.result_name)

    class SquaredDifference:
        """Element-wise (a - b) ** 2."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.squared_difference(a, b, name=self.result_name)

    class Equal:
        """Element-wise a == b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.equal(a, b, name=self.result_name)

    class GreaterEqual:
        """Element-wise a >= b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.greater_equal(a, b, name=self.result_name)

    class Greater:
        """Element-wise a > b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.greater(a, b, name=self.result_name)

    class Less:
        """Element-wise a < b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.less(a, b, name=self.result_name)

    class LessEqual:
        """Element-wise a <= b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.less_equal(a, b, name=self.result_name)

    class Conv2d:
        """NHWC 2-D convolution with the given weight, strides, padding, dilations."""

        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name=self.result_name,
            )

    class Conv2dRelu:
        """Unit-stride SAME NHWC conv2d followed by ReLU."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = TBuilder._conv2d_same(input, self.weight)
            return tf.nn.relu(conv2d, name=self.result_name)

    class Conv2dRelu6:
        """Unit-stride SAME NHWC conv2d followed by ReLU6."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = TBuilder._conv2d_same(input, self.weight)
            return tf.nn.relu6(conv2d, name=self.result_name)

    class Conv2dReluN1To1:
        """Unit-stride SAME NHWC conv2d followed by a clamp to [-1, 1]."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = TBuilder._conv2d_same(input, self.weight)
            return tf.clip_by_value(conv2d, -1.0, 1.0, name=self.result_name)

    class Conv2dTanh:
        """Unit-stride SAME NHWC conv2d followed by tanh."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = TBuilder._conv2d_same(input, self.weight)
            return tf.math.tanh(conv2d, name=self.result_name)

    class Conv2dWithBias:
        """NHWC conv2d followed by bias_add; the bias_add carries result_name."""

        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            conv2d_op = tf.nn.conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="conv2d",
            )
            bias_add_op = tf.nn.bias_add(
                conv2d_op, self.bias, data_format="NHWC", name=self.result_name
            )
            return bias_add_op

    class Conv3d:
        """NDHWC 3-D convolution."""

        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv3d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NDHWC",
                dilations=self.dilations,
                name=self.result_name,
            )

    class Conv3dWithBias:
        """NDHWC conv3d followed by bias_add (default data_format)."""

        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            conv3d_op = tf.nn.conv3d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NDHWC",
                dilations=self.dilations,
                name="conv3d",
            )
            bias_add_op = tf.nn.bias_add(conv3d_op, self.bias, name=self.result_name)
            return bias_add_op

    class DepthwiseConv2d:
        """NHWC depthwise conv2d; the result is an identity named result_name."""

        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            dws_conv2d = tf.nn.depthwise_conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="dws_conv2d",
            )
            return tf.identity(dws_conv2d, name=self.result_name)

    class DepthwiseConv2dWithBias:
        """NHWC depthwise conv2d followed by bias_add."""

        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            dws_conv2d = tf.nn.depthwise_conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="dws_conv2d",
            )
            bias_add_op = tf.nn.bias_add(
                dws_conv2d, self.bias, data_format="NHWC", name=self.result_name
            )
            return bias_add_op

    class TransposeConv2d:
        """NHWC transposed 2-D convolution with an explicit output shape."""

        def __init__(self, weight, output_shape, strides, padding, name):
            self.weight = weight
            self.output_shape = output_shape
            self.strides = strides
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv2d_transpose(
                input,
                self.weight,
                self.output_shape,
                self.strides,
                self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class Argmax:
        """Index of the maximum along ``axis``, as int32."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.argmax(a, self.axis, output_type=tf.int32, name=self.result_name)

    class AvgPool2d:
        """NHWC 2-D average pooling."""

        def __init__(self, strides, kernel_size, padding, name):
            self.strides = strides
            self.kernel_size = kernel_size
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.avg_pool2d(
                input,
                strides=self.strides,
                ksize=self.kernel_size,
                padding=self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class MaxPool2d:
        """NHWC 2-D max pooling."""

        def __init__(self, strides, kernel_size, padding, name):
            self.strides = strides
            self.kernel_size = kernel_size
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.max_pool2d(
                input,
                strides=self.strides,
                ksize=self.kernel_size,
                padding=self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class Reshape:
        """Reshape to ``shape``; the result is an identity named result_name."""

        def __init__(self, shape, name):
            self.shape = shape
            self.result_name = name

        def eval(self, a):
            reshape_op = tf.reshape(a, self.shape)
            return tf.identity(reshape_op, name=self.result_name)

    class Transpose:
        """Transpose with permutation ``perm``."""

        def __init__(self, perm, name):
            self.perm = perm
            self.result_name = name

        def eval(self, a):
            return tf.transpose(a, self.perm, name=self.result_name)

    class Slice:
        """Slice of ``size`` elements starting at ``begin``."""

        def __init__(self, begin, size, name):
            self.begin = begin
            self.size = size
            self.result_name = name

        def eval(self, a):
            return tf.slice(a, begin=self.begin, size=self.size, name=self.result_name)

    class StridedSlice:
        """tf.strided_slice with all mask parameters passed through."""

        def __init__(
            self,
            begin,
            end,
            strides,
            begin_mask,
            end_mask,
            ellipsis_mask,
            new_axis_mask,
            shrink_axis_mask,
            name,
        ):
            self.begin = begin
            self.end = end
            self.strides = strides
            self.begin_mask = begin_mask
            self.end_mask = end_mask
            self.ellipsis_mask = ellipsis_mask
            self.new_axis_mask = new_axis_mask
            self.shrink_axis_mask = shrink_axis_mask
            self.result_name = name

        def eval(self, a):
            return tf.strided_slice(
                a,
                begin=self.begin,
                end=self.end,
                strides=self.strides,
                begin_mask=self.begin_mask,
                end_mask=self.end_mask,
                ellipsis_mask=self.ellipsis_mask,
                new_axis_mask=self.new_axis_mask,
                shrink_axis_mask=self.shrink_axis_mask,
                name=self.result_name,
            )

    class Select:
        """Element-wise choice: ``a`` where ``selector`` is true, else ``b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, selector, a, b):
            return tf.where(condition=selector, x=a, y=b, name=self.result_name)

    class Addn:
        """Sum of four tensors (tf.add_n)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.add_n([a, b, c, d], name=self.result_name)

    class Concatv2:
        """Concatenates four tensors along ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)

    class Stack:
        """Stacks four tensors along a new ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.stack([a, b, c, d], axis=self.axis, name=self.result_name)

    class Unstack:
        """Unstacks along ``axis`` and returns a single tensor.

        If the axis has size 1, the lone slice is returned (as an identity).
        Otherwise each slice is reduced to its sum and the sums are stacked,
        so the result is one tensor rather than a list.
        """

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            unstack_op = tf.unstack(a, axis=self.axis, name="unstack_op")
            result_count = a.shape[self.axis]

            if result_count == 1:
                return tf.identity(unstack_op[0], name=self.result_name)

            sums = [
                tf.math.reduce_sum(unstack_op[i], name="reduce_{}".format(i))
                for i in range(result_count)
            ]
            return tf.stack(sums, 0, name=self.result_name)

    class MirrorPad:
        """tf.pad with a caller-chosen ``mode`` (e.g. REFLECT/SYMMETRIC)."""

        def __init__(self, padding, mode, name):
            self.padding = padding
            self.mode = mode
            self.result_name = name

        def eval(self, a):
            return tf.pad(
                a,
                self.padding,
                mode=self.mode,
                constant_values=0,
                name=self.result_name,
            )

    class Pad:
        """Constant-mode tf.pad filling with ``pad_const``."""

        def __init__(self, padding, pad_const, name):
            self.padding = padding
            self.pad_const = pad_const
            self.result_name = name

        def eval(self, a):
            return tf.pad(
                a,
                self.padding,
                mode="CONSTANT",
                constant_values=self.pad_const,
                name=self.result_name,
            )

    class ExpandDims:
        """Inserts a size-1 dimension at ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.expand_dims(a, self.axis, name=self.result_name)

    class Shape:
        """Returns the shape of the input as a tensor."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.shape(a, name=self.result_name)

    class Rank:
        """Returns the rank of the input as a tensor."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.rank(a, name=self.result_name)

    class Fill:
        """Tensor of ``shape`` filled with ``value``; the eval input is ignored."""

        def __init__(self, shape, value, name):
            self.shape = shape
            self.value = value
            self.result_name = name

        def eval(self, a):
            return tf.fill(self.shape, self.value, name=self.result_name)

    class Elu:
        """ELU activation."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.elu(a, name=self.result_name)

    class Softmax:
        """Softmax (tf.nn.softmax default axis)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.softmax(a, name=self.result_name)

    class LogSoftmax:
        """Log-softmax (tf.nn.log_softmax default axis)."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.log_softmax(a, name=self.result_name)

    class MatMul:
        """Matrix product a @ b."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.linalg.matmul(a, b, name=self.result_name)

    class AddScalar:
        """Adds the constant 1 to the input."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.add(a, 1, name=self.result_name)

    class Add1d:
        """Adds ``b`` to ``a`` after collapsing ``b`` to 1-D.

        When ``b`` has rank > 1 it is summed over every axis except the last.
        """

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            if len(b.shape) > 1:
                b_1d = tf.reduce_sum(b, axis=list(range(len(b.shape) - 1)))
            else:
                b_1d = b
            return tf.add(a, b_1d, name=self.result_name)

    class Split:
        """Splits along ``axis``; returns the stacked sum of each piece.

        ``num_splits`` is either a split count or a list of piece sizes.
        """

        def __init__(self, num_splits, axis, name):
            self.num_splits = num_splits
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            # The split op generates a list of outputs. Since we have difficulty
            # serializing a list or array of Numpy arrays, we will reduce each of
            # the results.
            if not isinstance(self.num_splits, list):
                split_op = tf.split(
                    a, num_or_size_splits=self.num_splits, axis=self.axis, name="split"
                )
                result_count = self.num_splits
            else:
                num_split = np.asarray(self.num_splits, dtype=np.int32)
                split_vec_op = tf.compat.v1.constant(
                    num_split,
                    shape=num_split.shape,
                    dtype=tf.int32,
                    name="const_split_vec",
                )
                split_op = tf.split(
                    a, num_or_size_splits=split_vec_op, axis=self.axis, name="split"
                )
                result_count = num_split.shape[0]

            sums = [
                tf.math.reduce_sum(split_op[i], name="reduce_{}".format(i))
                for i in range(result_count)
            ]
            return tf.stack(sums, 0, name=self.result_name)

    class Tile:
        """Tiles the input by ``multiples``; the result is an identity."""

        def __init__(self, multiples, name):
            self.multiples = multiples
            self.result_name = name

        def eval(self, a):
            t = tf.tile(a, self.multiples, name="tile")
            return tf.identity(t, name=self.result_name)

    class Reverse:
        """Reverses the input along the single axis ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.reverse(a, [self.axis], name=self.result_name)

    class Gather:
        """tf.gather with fixed ``indices``, ``batch_dims`` and ``axis``."""

        def __init__(self, indices, batch_dims, axis, name):
            self.indices = indices
            self.batch_dims = batch_dims
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.gather(
                a,
                self.indices,
                batch_dims=self.batch_dims,
                axis=self.axis,
                name=self.result_name,
            )

    class GatherNd:
        """tf.gather_nd with fixed ``indices``."""

        def __init__(self, indices, name):
            self.indices = indices
            self.result_name = name

        def eval(self, a):
            return tf.gather_nd(a, self.indices, name=self.result_name)

    class ScatterNd:
        """tf.scatter_nd with random indices/updates generated from ``rng``."""

        def __init__(self, shape, indices_shape, N, rng, name):
            self.shape = shape
            self.indices_shape = indices_shape
            self.N = N
            self.rng = rng
            self.result_name = name

        def eval(self, a):
            # This operator is special. The indices and updates tensors really need
            # to be created together, but in the current structure of this tool there
            # is no way to do that before now. The number of updates is determined by
            # the indices, so we can really only create that after indices; but we
            # don't know the type at that time.
            #
            # Shapes are guaranteed deterministic, but we'll use our rng
            # copied from the arggen stage. It's possible that index and
            # update *values* will be non-deterministic.
            #
            # The input tensor `a` is used only to obtain the dtype of the updates.

            shape_const = tf.constant(self.shape, tf.int32)

            updates_shape = list(self.indices_shape[:-1])
            updates_shape.extend(self.shape[self.indices_shape[-1] :])

            updates_const = tf.constant(TGen.getRand(updates_shape, a.dtype, self.rng))

            indices = np.zeros(self.indices_shape, dtype=np.int32)

            # Generate the random indices based on the limits of 'shape' for each
            # dimension: the element at position k of the last indices dimension
            # indexes shape[k]. The tensors are small, so this is done one
            # element at a time (which also keeps the rng draw order fixed).
            for idx in np.ndindex(indices.shape):
                indices[idx] = self.rng.integers(0, self.shape[idx[-1]], size=1)[0]

            indices_const = tf.constant(indices, dtype=tf.int32)

            return tf.scatter_nd(
                indices=indices_const,
                updates=updates_const,
                shape=shape_const,
                name=self.result_name,
            )

    class SpaceToBatch:
        """tf.space_to_batch with ``block_shape`` and ``padding``."""

        def __init__(self, block_shape, padding, name):
            self.block_shape = block_shape
            self.padding = padding
            self.result_name = name

        def eval(self, a):
            return tf.space_to_batch(
                a, self.block_shape, self.padding, name=self.result_name
            )

    class BatchToSpace:
        """Swaps the first and last axes, then applies tf.batch_to_space."""

        def __init__(self, block_shape, cropping, name):
            self.block_shape = block_shape
            self.cropping = cropping
            self.result_name = name

        @staticmethod
        def _swap_batch_and_channel_perm(rank):
            """Return a permutation of ``rank`` axes swapping axis 0 and the last axis."""
            return [rank - 1] + list(range(1, rank - 1)) + [0]

        def eval(self, a):
            # Transpose to swap depth and batch first; this avoids having to
            # generate a new input shape. Axes between them keep their order,
            # which also works when the input has dims beyond the block dims.
            perm = self._swap_batch_and_channel_perm(len(a.shape))
            transpose_op = tf.transpose(a, perm)
            return tf.batch_to_space(
                transpose_op, self.block_shape, self.cropping, name=self.result_name
            )

    class SpaceToDepth:
        """tf.nn.space_to_depth with block size ``block_shape``."""

        def __init__(self, block_shape, name):
            self.block_shape = block_shape
            self.result_name = name

        def eval(self, a):
            return tf.nn.space_to_depth(a, self.block_shape, name=self.result_name)

    class DepthToSpace:
        """tf.nn.depth_to_space with block size ``block_shape``."""

        def __init__(self, block_shape, name):
            self.block_shape = block_shape
            self.result_name = name

        def eval(self, a):
            return tf.nn.depth_to_space(a, self.block_shape, name=self.result_name)

    class OneHot:
        """tf.one_hot whose output dtype follows ``on_value``."""

        def __init__(self, depth, axis, name):
            self.depth = depth
            self.axis = axis
            self.result_name = name

        def eval(self, indices, on_value, off_value):
            return tf.one_hot(
                indices,
                self.depth,
                on_value=on_value,
                off_value=off_value,
                axis=self.axis,
                dtype=on_value.dtype,
                name=self.result_name,
            )

    class Fakequant:
        """Fake quantization over the fixed range [-2.0, 2.0]."""

        def __init__(self, num_bits, narrow_range, name):
            self.num_bits = num_bits
            self.narrow_range = narrow_range
            self.result_name = name

        def eval(self, a):
            return tf.quantization.fake_quant_with_min_max_args(
                a,
                min=-2.0,
                max=2.0,
                num_bits=self.num_bits,
                narrow_range=self.narrow_range,
                name=self.result_name,
            )

    class Resize:
        """Resizes axes 1 and 2 of the input by the factor ``scale``.

        ``mode == "nearest"`` selects nearest-neighbour resizing; any other
        mode selects bilinear. ``align`` and ``half`` are passed as
        ``align_corners`` and ``half_pixel_centers``.
        """

        def __init__(self, mode, align, half, scale, name):
            self.result_name = name
            self.mode = mode
            self.align = align
            self.half = half
            self.scale = scale

        def eval(self, a):
            out_shape = [a.shape[1] * self.scale, a.shape[2] * self.scale]

            resize_func = (
                tf.compat.v1.image.resize_nearest_neighbor
                if self.mode == "nearest"
                else tf.compat.v1.image.resize_bilinear
            )
            resize = resize_func(
                a,
                out_shape,
                align_corners=self.align,
                name="resize",
                half_pixel_centers=self.half,
            )
            return tf.identity(resize, name=self.result_name)

    class LeftShift:
        """Bitwise left shift by ``shift``."""

        def __init__(self, shift, name):
            self.shift = shift
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.left_shift(a, self.shift, name=self.result_name)

    class RightShift:
        """Bitwise right shift by ``shift``."""

        def __init__(self, shift, name):
            self.shift = shift
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.right_shift(a, self.shift, name=self.result_name)

    class While:
        """Repeats ``x = x + sigmoid(x)`` while ``sum(x) < 2.0``."""

        def __init__(self, name):
            self.result_name = name
            # Loop bound tensor; (re)built in eval() with the input's dtype.
            self.cap = None

        def while_cond(self, x):
            return tf.reduce_sum(x) < self.cap

        def while_body(self, x):
            return tf.add(x, tf.math.sigmoid(x))

        def eval(self, a):
            self.cap = tf.cast(tf.constant(2.0, shape=[1]), a.dtype)

            result = tf.while_loop(
                self.while_cond, self.while_body, [a], name=self.result_name
            )

            return result[0]

    class LSTM:
        """Keras LSTM layer (2 units, all-ones kernels); result_name is unused."""

        def __init__(self, name):
            self.result_name = name
            self.lstm = tf.keras.layers.LSTM(
                2,
                activation="tanh",
                unroll=False,
                recurrent_activation="sigmoid",
                use_bias=True,
                recurrent_initializer="ones",
                kernel_initializer="ones",
            )

        def eval(self, a):
            return self.lstm(a)

    class GRU:
        """Keras GRU layer (2 units, all-ones kernels); result_name is unused."""

        def __init__(self, name):
            self.result_name = name
            # Historical attribute name: this holds a GRU layer, not an LSTM.
            self.lstm = tf.keras.layers.GRU(
                2,
                recurrent_activation="sigmoid",
                use_bias=True,
                recurrent_initializer="ones",
                kernel_initializer="ones",
            )

        def eval(self, a):
            return self.lstm(a)

    class RNN:
        """Keras RNN over a 2-unit SimpleRNNCell; result_name is unused."""

        def __init__(self, name):
            self.result_name = name
            basic_cell = tf.keras.layers.SimpleRNNCell(
                units=2,
                activation="sigmoid",
                use_bias=True,
                recurrent_initializer="ones",
            )
            self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)

        def eval(self, a):
            return self.rnn(a)

    class FullyConnected:
        """Keras Dense layer with 2 units; result_name is unused."""

        def __init__(self, name):
            self.result_name = name
            self.dense = tf.keras.layers.Dense(2)

        def eval(self, a):
            return self.dense(a)

    class RFFT2d:
        """2-D real FFT over the innermost two dims with ``fft_length``."""

        def __init__(self, fft_length, name):
            self.fft_length = fft_length
            self.result_name = name

        def eval(self, a):
            return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
1217 return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)