blob: 900f5fa10a3e9aec1365b2d9011b3a6745cc5bf8 [file] [log] [blame]
TatWai Chongbef907a2024-01-23 09:40:37 -08001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson015c3552022-02-23 12:15:03 +00002# SPDX-License-Identifier: Apache-2.0
Jerry Gec4448362024-03-21 17:33:55 +00003import os
4
Jeremy Johnson015c3552022-02-23 12:15:03 +00005import numpy as np
6import tensorflow as tf
7from frameworks.tensor_gen import TGen
8
# Ask TensorFlow for the legacy Keras implementation when tf.keras is used.
# NOTE(review): TensorFlow appears to read TF_USE_LEGACY_KERAS lazily, when
# tf.keras is first resolved. Setting it after `import tensorflow` therefore
# only works if nothing imported above has already touched tf.keras -- confirm.
# Exporting it before importing tensorflow would be more robust.
os.environ.update({"TF_USE_LEGACY_KERAS": "1"})
10
Jeremy Johnson015c3552022-02-23 12:15:03 +000011
12class TBuilder:
13 """The member functions build the tensorflow operators into small networks
14 for our tests"""
15
def __init__(self):
    """Create a builder; there is no per-instance state to set up."""
18
def fake_quant(tensor, tensor_scale, name):
    """Fake-quantize ``tensor`` with parameters read from ``tensor_scale``.

    ``tensor_scale`` must expose ``min``, ``max``, ``num_bits`` and
    ``narrow_range``. Each is forwarded unchanged to
    ``tf.quantization.fake_quant_with_min_max_args``, and ``name`` names the
    resulting op.
    """
    range_min = tensor_scale.min
    range_max = tensor_scale.max
    return tf.quantization.fake_quant_with_min_max_args(
        tensor,
        min=range_min,
        max=range_max,
        num_bits=tensor_scale.num_bits,
        narrow_range=tensor_scale.narrow_range,
        name=name,
    )
29
def fake_quant_params(tensor, min, max, scaling, name):
    """Fake-quantize ``tensor`` over the range given by ``min`` and ``max``.

    Only ``num_bits`` and ``narrow_range`` are read from ``scaling``. Note that
    the ``min``/``max`` parameters shadow the builtins inside this function;
    they are kept for caller compatibility.
    """
    bits = scaling.num_bits
    narrow = scaling.narrow_range
    return tf.quantization.fake_quant_with_min_max_args(
        tensor, min=min, max=max, num_bits=bits, narrow_range=narrow, name=name
    )
40
class Add:
    """Element-wise addition ``a + b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.add(a, b, name=self.result_name)
47
class Sub:
    """Element-wise difference ``a - b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.subtract(a, b, name=self.result_name)
54
class Mul:
    """Element-wise product ``a * b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.multiply(a, b, name=self.result_name)
61
class Exp:
    """Element-wise natural exponential."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.exp(a, name=self.result_name)
68
class Rcp:
    """Element-wise reciprocal ``1 / a``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.reciprocal(a, name=self.result_name)
75
class Relu:
    """Rectified linear unit ``max(a, 0)``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.nn.relu(a, name=self.result_name)
82
class Relu1:
    """Clamp the input to the range [-1, 1] (a ReLU-N1-to-1 activation)."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        # TF doesn't have a relu_n1_to_1 operator, so use min and max as a
        # workaround (tf.clip_by_value would be an alternative). The outer op
        # carries result_name, consistent with the other single-op builders.
        return tf.math.minimum(
            1.0, tf.math.maximum(-1.0, a), name=self.result_name
        )
92
class Relu0To1:
    """Clamp the input to the range [0, 1] (a ReLU-0-to-1 activation)."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        # TF doesn't have a relu_0_to_1 operator, so use min and max as a
        # workaround (tf.clip_by_value would be an alternative). The outer op
        # carries result_name, consistent with the other single-op builders.
        return tf.math.minimum(
            1.0, tf.math.maximum(0.0, a), name=self.result_name
        )
102
class Relu6:
    """ReLU capped at 6."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.nn.relu6(a, name=self.result_name)
109
class LeakyRelu:
    """Leaky ReLU with a configurable negative-side slope ``alpha``."""

    def __init__(self, alpha, name):
        self.alpha = alpha
        self.result_name = name

    def eval(self, a):
        slope = self.alpha
        return tf.nn.leaky_relu(a, alpha=slope, name=self.result_name)
117
class Prelu:
    """Keras PReLU layer whose alpha weights are initialised from N(0, 1)."""

    def __init__(self, name):
        # Stored for interface parity; the Keras layer output is not renamed.
        self.result_name = name
        alpha_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init)

    def eval(self, a):
        return self.prelu(a)
129
class Gelu:
    """GELU activation with TF's default settings."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.nn.gelu(a, name=self.result_name)
136
class Concat:
    """Concatenate two tensors along ``axis``; scalar inputs are stacked."""

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a, b):
        pair = [a, b]
        if a.shape != ():
            return tf.concat(pair, self.axis, name=self.result_name)
        # Rank-0 inputs have no axis to concatenate along; stack them instead.
        return tf.stack(pair, name=self.result_name)
Jeremy Johnson015c3552022-02-23 12:15:03 +0000148
class BitwiseAnd:
    """Element-wise bitwise AND of two integer tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.bitwise.bitwise_and(a, b, name=self.result_name)
155
class BitwiseOr:
    """Element-wise bitwise OR of two integer tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.bitwise.bitwise_or(a, b, name=self.result_name)
162
class BitwiseNot:
    """Element-wise bitwise inversion."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.bitwise.invert(a, name=self.result_name)
169
class BitwiseXor:
    """Element-wise bitwise XOR of two integer tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.bitwise.bitwise_xor(a, b, name=self.result_name)
176
class LogicalAnd:
    """Element-wise logical AND of two boolean tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.logical_and(a, b, name=self.result_name)
183
class LogicalOr:
    """Element-wise logical OR of two boolean tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.logical_or(a, b, name=self.result_name)
190
class LogicalNot:
    """Element-wise logical negation of a boolean tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.logical_not(a, name=self.result_name)
197
class ReduceAny:
    """Logical-OR reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_any(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
208
class ReduceAll:
    """Logical-AND reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_all(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
219
class ReduceMin:
    """Minimum reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_min(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
230
class ReduceMax:
    """Maximum reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_max(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
241
class ReduceSum:
    """Sum reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_sum(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
252
class ReduceMean:
    """Mean reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_mean(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
263
class ReduceProduct:
    """Product reduction over ``axis_list``."""

    def __init__(self, axis_list, keepdims, name):
        self.axis_list = axis_list
        self.keepdims = keepdims
        self.result_name = name

    def eval(self, a):
        return tf.math.reduce_prod(
            a, axis=self.axis_list, keepdims=self.keepdims, name=self.result_name
        )
274
class Min:
    """Element-wise minimum of two tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.minimum(a, b, name=self.result_name)
281
class Max:
    """Element-wise maximum of two tensors."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.maximum(a, b, name=self.result_name)
288
class Pow:
    """Element-wise power ``a ** b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.pow(a, b, name=self.result_name)
295
class Abs:
    """Element-wise absolute value."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.abs(a, name=self.result_name)
302
class Ceil:
    """Element-wise ceiling."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.ceil(a, name=self.result_name)
309
class Floor:
    """Element-wise floor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.floor(a, name=self.result_name)
316
class Log:
    """Element-wise natural logarithm."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.log(a, name=self.result_name)
323
class Negate:
    """Element-wise negation ``-a``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.negative(a, name=self.result_name)
330
class Rsqrt:
    """Element-wise reciprocal square root."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.rsqrt(a, name=self.result_name)
337
class Sign:
    """Element-wise sign."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.sign(a, name=self.result_name)
344
class Sigmoid:
    """Element-wise logistic sigmoid."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.sigmoid(a, name=self.result_name)
351
class Tanh:
    """Element-wise hyperbolic tangent."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.tanh(a, name=self.result_name)
358
class Erf:
    """Element-wise Gauss error function."""

    # tfl.ops cannot be generated right now.
    # https://github.com/tensorflow/tensorflow/issues/60809

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.erf(a, name=self.result_name)
367
class Sin:
    """Element-wise sine."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.sin(a, name=self.result_name)
374
class Cos:
    """Element-wise cosine."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.cos(a, name=self.result_name)
381
class Atan2:
    """Element-wise two-argument arctangent of ``a`` and ``b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        # ``a`` is TF's first argument and ``b`` its second; the order matters.
        return tf.math.atan2(a, b, name=self.result_name)
388
class Square:
    """Element-wise square ``a * a``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.math.square(a, name=self.result_name)
395
class SquaredDifference:
    """Element-wise ``(a - b) ** 2``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.squared_difference(a, b, name=self.result_name)
402
class Equal:
    """Element-wise ``a == b`` producing a boolean tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.equal(a, b, name=self.result_name)
409
class GreaterEqual:
    """Element-wise ``a >= b`` producing a boolean tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.greater_equal(a, b, name=self.result_name)
416
class Greater:
    """Element-wise ``a > b`` producing a boolean tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.greater(a, b, name=self.result_name)
423
class Less:
    """Element-wise ``a < b`` producing a boolean tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.less(a, b, name=self.result_name)
430
class LessEqual:
    """Element-wise ``a <= b`` producing a boolean tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.math.less_equal(a, b, name=self.result_name)
437
class Conv2d:
    """NHWC 2D convolution with configurable strides, padding and dilations."""

    def __init__(self, weight, strides, padding, dilations, name):
        self.weight = weight
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.result_name = name

    def eval(self, input):
        return tf.nn.conv2d(
            input,
            filters=self.weight,
            strides=self.strides,
            padding=self.padding,
            data_format="NHWC",
            dilations=self.dilations,
            name=self.result_name,
        )
456
class Conv2dRelu:
    """Unit-stride, SAME-padded NHWC conv2d followed by ReLU."""

    def __init__(self, weight, name):
        self.weight = weight
        self.result_name = name

    def eval(self, input):
        conv = tf.nn.conv2d(
            input,
            filters=self.weight,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            dilations=[1, 1, 1, 1],
            name="conv2d",
        )
        return tf.nn.relu(conv, name=self.result_name)
473
class Conv2dRelu6:
    """Unit-stride, SAME-padded NHWC conv2d followed by ReLU6."""

    def __init__(self, weight, name):
        self.weight = weight
        self.result_name = name

    def eval(self, input):
        conv = tf.nn.conv2d(
            input,
            filters=self.weight,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            dilations=[1, 1, 1, 1],
            name="conv2d",
        )
        return tf.nn.relu6(conv, name=self.result_name)
490
class Conv2dReluN1To1:
    """Unit-stride, SAME-padded NHWC conv2d clipped to [-1, 1]."""

    def __init__(self, weight, name):
        self.weight = weight
        self.result_name = name

    def eval(self, input):
        conv = tf.nn.conv2d(
            input,
            filters=self.weight,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            dilations=[1, 1, 1, 1],
            name="conv2d",
        )
        return tf.clip_by_value(conv, -1.0, 1.0, name=self.result_name)
507
class Conv2dTanh:
    """Unit-stride, SAME-padded NHWC conv2d followed by tanh."""

    def __init__(self, weight, name):
        self.weight = weight
        self.result_name = name

    def eval(self, input):
        conv = tf.nn.conv2d(
            input,
            filters=self.weight,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            dilations=[1, 1, 1, 1],
            name="conv2d",
        )
        return tf.math.tanh(conv, name=self.result_name)
524
class Conv2dWithBias:
    """NHWC conv2d followed by a per-channel bias add."""

    def __init__(self, weight, bias, strides, padding, dilations, name):
        self.weight = weight
        self.bias = bias
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.result_name = name

    def eval(self, input):
        conv = tf.nn.conv2d(
            input,
            filters=self.weight,
            strides=self.strides,
            padding=self.padding,
            data_format="NHWC",
            dilations=self.dilations,
            name="conv2d",
        )
        # Only the bias add carries result_name; the conv is named "conv2d".
        return tf.nn.bias_add(
            conv, self.bias, data_format="NHWC", name=self.result_name
        )
548
class Conv3d:
    """NDHWC 3D convolution with configurable strides, padding and dilations."""

    def __init__(self, weight, strides, padding, dilations, name):
        self.weight = weight
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.result_name = name

    def eval(self, input):
        return tf.nn.conv3d(
            input,
            filters=self.weight,
            strides=self.strides,
            padding=self.padding,
            data_format="NDHWC",
            dilations=self.dilations,
            name=self.result_name,
        )
567
class Conv3dWithBias:
    """NDHWC conv3d followed by a bias add."""

    def __init__(self, weight, bias, strides, padding, dilations, name):
        self.weight = weight
        self.bias = bias
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.result_name = name

    def eval(self, input):
        conv = tf.nn.conv3d(
            input,
            filters=self.weight,
            strides=self.strides,
            padding=self.padding,
            data_format="NDHWC",
            dilations=self.dilations,
            name="conv3d",
        )
        # bias_add is called without an explicit data_format here.
        return tf.nn.bias_add(conv, self.bias, name=self.result_name)
589
class DepthwiseConv2d:
    """NHWC depthwise 2D convolution, renamed via a trailing identity."""

    def __init__(self, weight, strides, padding, dilations, name):
        self.weight = weight
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.result_name = name

    def eval(self, input):
        depthwise = tf.nn.depthwise_conv2d(
            input,
            self.weight,
            self.strides,
            self.padding,
            data_format="NHWC",
            dilations=self.dilations,
            name="dws_conv2d",
        )
        # The conv op is named "dws_conv2d"; the identity carries result_name.
        return tf.identity(depthwise, name=self.result_name)
609
class DepthwiseConv2dWithBias:
    """NHWC depthwise 2D convolution followed by a bias add."""

    def __init__(self, weight, bias, strides, padding, dilations, name):
        self.weight = weight
        self.bias = bias
        self.strides = strides
        self.padding = padding
        self.dilations = dilations
        self.result_name = name

    def eval(self, input):
        depthwise = tf.nn.depthwise_conv2d(
            input,
            self.weight,
            self.strides,
            self.padding,
            data_format="NHWC",
            dilations=self.dilations,
            name="dws_conv2d",
        )
        return tf.nn.bias_add(
            depthwise, self.bias, data_format="NHWC", name=self.result_name
        )
633
class TransposeConv2d:
    """NHWC transposed 2D convolution producing ``output_shape``."""

    def __init__(self, weight, output_shape, strides, padding, name):
        self.weight = weight
        self.output_shape = output_shape
        self.strides = strides
        self.padding = padding
        self.result_name = name

    def eval(self, input):
        out_shape = self.output_shape
        return tf.nn.conv2d_transpose(
            input,
            self.weight,
            out_shape,
            self.strides,
            self.padding,
            data_format="NHWC",
            name=self.result_name,
        )
652
class Argmax:
    """Index of the maximum along ``axis``, returned as int32."""

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a):
        return tf.argmax(
            a, axis=self.axis, output_type=tf.int32, name=self.result_name
        )
660
class AvgPool2d:
    """NHWC 2D average pooling."""

    def __init__(self, strides, kernel_size, padding, name):
        self.strides = strides
        self.kernel_size = kernel_size
        self.padding = padding
        self.result_name = name

    def eval(self, input):
        return tf.nn.avg_pool2d(
            input,
            ksize=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format="NHWC",
            name=self.result_name,
        )
677
class MaxPool2d:
    """NHWC 2D max pooling."""

    def __init__(self, strides, kernel_size, padding, name):
        self.strides = strides
        self.kernel_size = kernel_size
        self.padding = padding
        self.result_name = name

    def eval(self, input):
        return tf.nn.max_pool2d(
            input,
            ksize=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format="NHWC",
            name=self.result_name,
        )
694
class Reshape:
    """Reshape to ``shape``; the output is named via a trailing identity."""

    def __init__(self, shape, name):
        self.shape = shape
        self.result_name = name

    def eval(self, a):
        # The reshape itself is left unnamed; the identity carries result_name.
        reshaped = tf.reshape(a, self.shape)
        return tf.identity(reshaped, name=self.result_name)
703
class Transpose:
    """Permute dimensions according to ``perm``."""

    def __init__(self, perm, name):
        self.perm = perm
        self.result_name = name

    def eval(self, a):
        return tf.transpose(a, perm=self.perm, name=self.result_name)
711
class Slice:
    """Extract a ``size``-shaped slice starting at ``begin``."""

    def __init__(self, begin, size, name):
        self.begin = begin
        self.size = size
        self.result_name = name

    def eval(self, a):
        start, extent = self.begin, self.size
        return tf.slice(a, begin=start, size=extent, name=self.result_name)
720
class StridedSlice:
    """``tf.strided_slice`` with every begin/end/stride and mask configurable."""

    def __init__(
        self,
        begin,
        end,
        strides,
        begin_mask,
        end_mask,
        ellipsis_mask,
        new_axis_mask,
        shrink_axis_mask,
        name,
    ):
        self.begin = begin
        self.end = end
        self.strides = strides
        self.begin_mask = begin_mask
        self.end_mask = end_mask
        self.ellipsis_mask = ellipsis_mask
        self.new_axis_mask = new_axis_mask
        self.shrink_axis_mask = shrink_axis_mask
        self.result_name = name

    def eval(self, a):
        masks = {
            "begin_mask": self.begin_mask,
            "end_mask": self.end_mask,
            "ellipsis_mask": self.ellipsis_mask,
            "new_axis_mask": self.new_axis_mask,
            "shrink_axis_mask": self.shrink_axis_mask,
        }
        return tf.strided_slice(
            a,
            begin=self.begin,
            end=self.end,
            strides=self.strides,
            name=self.result_name,
            **masks,
        )
757
class Select:
    """Element-wise choose ``a`` where ``selector`` is true, else ``b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, selector, a, b):
        return tf.where(condition=selector, x=a, y=b, name=self.result_name)
764
class Addn:
    """Sum of four tensors via a single ``tf.add_n``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b, c, d):
        operands = [a, b, c, d]
        return tf.add_n(operands, name=self.result_name)
771
class Concatv2:
    """Concatenate four tensors along ``axis``; scalar inputs are stacked."""

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a, b, c, d):
        parts = [a, b, c, d]
        if a.shape != ():
            return tf.concat(parts, axis=self.axis, name=self.result_name)
        # Rank-0 inputs have no axis to concatenate along; stack them instead.
        return tf.stack(parts, name=self.result_name)
Jeremy Johnson015c3552022-02-23 12:15:03 +0000783
class Stack:
    """Stack four tensors along a new ``axis``."""

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a, b, c, d):
        parts = [a, b, c, d]
        return tf.stack(parts, axis=self.axis, name=self.result_name)
791
class Unstack:
    """Unstack along ``axis`` and collapse the pieces into one output tensor.

    A single piece is returned as-is (through an identity). Otherwise each
    piece is reduce-summed and the scalars are stacked, so the test has a
    single output tensor rather than a list.
    """

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a):
        pieces = tf.unstack(a, axis=self.axis, name="unstack_op")
        count = a.shape[self.axis]

        if count == 1:
            return tf.identity(pieces[0], name=self.result_name)

        sums = [
            tf.math.reduce_sum(pieces[i], name=f"reduce_{i}") for i in range(count)
        ]
        return tf.stack(sums, 0, name=self.result_name)
810
class MirrorPad:
    """Pad with ``tf.pad`` using the given non-constant ``mode``."""

    def __init__(self, padding, mode, name):
        self.padding = padding
        self.mode = mode
        self.result_name = name

    def eval(self, a):
        # constant_values is always passed as 0. TF documents it as being
        # used by CONSTANT mode only.
        return tf.pad(
            a, self.padding, mode=self.mode, constant_values=0, name=self.result_name
        )
825
class Pad:
    """Constant padding with fill value ``pad_const``."""

    def __init__(self, padding, pad_const, name):
        self.padding = padding
        self.pad_const = pad_const
        self.result_name = name

    def eval(self, a):
        fill = self.pad_const
        return tf.pad(
            a, self.padding, mode="CONSTANT", constant_values=fill, name=self.result_name
        )
840
class ExpandDims:
    """Insert a size-1 dimension at ``axis``."""

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a):
        return tf.expand_dims(a, axis=self.axis, name=self.result_name)
848
class Shape:
    """Return the runtime shape of the input as a tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.shape(a, name=self.result_name)
855
class Rank:
    """Return the rank of the input as a tensor."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.rank(a, name=self.result_name)
862
class Fill:
    """Produce a ``shape``-shaped tensor filled with ``value``."""

    def __init__(self, shape, value, name):
        self.shape = shape
        self.value = value
        self.result_name = name

    def eval(self, a):
        # ``a`` is accepted for interface uniformity but is not used.
        return tf.fill(self.shape, self.value, name=self.result_name)
871
class Elu:
    """Exponential linear unit activation."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.nn.elu(a, name=self.result_name)
878
class Softmax:
    """Softmax over TF's default (last) axis."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.nn.softmax(a, name=self.result_name)
885
class LogSoftmax:
    """Log-softmax over TF's default (last) axis."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.nn.log_softmax(a, name=self.result_name)
892
class DynamicLinear:
    """Keras model: an ``Input`` of ``dynamic_input_shape`` then ``Dense(5)``."""

    def __init__(self, dynamic_input_shape, name):
        # Stored for interface parity; the Keras model output is not renamed.
        self.result_name = name
        layers = tf.keras.layers
        self.model = tf.keras.Sequential(
            [layers.Input(shape=dynamic_input_shape), layers.Dense(units=5)]
        )

    def eval(self, a):
        return self.model(a)
905
class MatMul:
    """Matrix product ``a @ b``."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        return tf.linalg.matmul(a, b, name=self.result_name)
912
class AddScalar:
    """Add the constant 1 to every element."""

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a):
        return tf.add(a, 1, name=self.result_name)
919
class Add1d:
    """Add a 1-D tensor derived from ``b`` to ``a``.

    If ``b`` has rank > 1 it is first reduce-summed over every axis except
    the last, leaving a vector.
    """

    def __init__(self, name):
        # Passed as ``name=`` to the op producing the output.
        self.result_name = name

    def eval(self, a, b):
        rank = len(b.shape)
        b_1d = tf.reduce_sum(b, axis=list(range(rank - 1))) if rank > 1 else b
        return tf.add(a, b_1d, name=self.result_name)
930
class Split:
    """Split ``a`` along ``axis`` and reduce each piece to a scalar.

    ``num_splits`` is either an int (number of equal pieces) or a sequence of
    explicit piece sizes. Lists and tuples are both treated as sizes.
    Previously a tuple reached the int branch and failed on ``range(tuple)``.
    """

    def __init__(self, num_splits, axis, name):
        self.num_splits = num_splits
        self.axis = axis
        self.result_name = name

    @staticmethod
    def _size_splits(num_splits):
        """Return explicit split sizes as an int32 array, or None for a count."""
        if isinstance(num_splits, (list, tuple)):
            return np.asarray(num_splits, dtype=np.int32)
        return None

    def eval(self, a):
        # The split op generates a list of outputs. Since we have difficulty
        # serializing a list or array of Numpy arrays, we reduce each of the
        # results and stack the sums into a single tensor.
        sizes = self._size_splits(self.num_splits)
        if sizes is None:
            split_op = tf.split(
                a, num_or_size_splits=self.num_splits, axis=self.axis, name="split"
            )
            result_count = self.num_splits
        else:
            split_vec_op = tf.compat.v1.constant(
                sizes,
                shape=sizes.shape,
                dtype=tf.int32,
                name="const_split_vec",
            )
            split_op = tf.split(
                a, num_or_size_splits=split_vec_op, axis=self.axis, name="split"
            )
            result_count = sizes.shape[0]

        sums = [
            tf.math.reduce_sum(split_op[i], name="reduce_{}".format(i))
            for i in range(result_count)
        ]
        return tf.stack(sums, 0, name=self.result_name)
964
class Tile:
    """Tile the input by ``multiples``; the output is named via an identity."""

    def __init__(self, multiples, name):
        self.multiples = multiples
        self.result_name = name

    def eval(self, a):
        tiled = tf.tile(a, self.multiples, name="tile")
        return tf.identity(tiled, name=self.result_name)
973
class Reverse:
    """Reverse the input along the single axis ``axis``."""

    def __init__(self, axis, name):
        self.axis = axis
        self.result_name = name

    def eval(self, a):
        # tf.reverse takes a list of axes; wrap the single axis.
        axes = [self.axis]
        return tf.reverse(a, axes, name=self.result_name)
981
class Gather:
    """Gather slices at constant ``indices`` along ``axis``."""

    def __init__(self, indices, batch_dims, axis, name):
        self.indices = indices
        self.batch_dims = batch_dims
        self.axis = axis
        self.result_name = name

    def eval(self, a):
        return tf.gather(
            a,
            self.indices,
            axis=self.axis,
            batch_dims=self.batch_dims,
            name=self.result_name,
        )
997
class GatherNd:
    """Gather slices addressed by the N-d constant ``indices``."""

    def __init__(self, indices, name):
        self.indices = indices
        self.result_name = name

    def eval(self, a):
        return tf.gather_nd(a, self.indices, name=self.result_name)
1005
class ScatterNd:
    """Scatter randomly generated updates into a zero tensor of ``shape``.

    The indices and updates really need to be created together. The number
    of updates depends on ``indices_shape``, but their dtype is only known
    once the input ``a`` is seen, so both are built here and ``a`` is used
    solely for its dtype. Shapes are deterministic; values come from ``rng``,
    copied from the arggen stage. ``N`` is stored but not used by ``eval``.
    """

    def __init__(self, shape, indices_shape, N, rng, name):
        self.shape = shape
        self.indices_shape = indices_shape
        self.N = N
        self.rng = rng
        self.result_name = name

    def eval(self, a):
        shape_const = tf.constant(self.shape, tf.int32)

        # updates shape = indices_shape[:-1] + shape[indices_shape[-1]:]
        index_depth = self.indices_shape[-1]
        updates_shape = list(self.indices_shape[:-1]) + list(self.shape[index_depth:])
        # Updates are generated before the indices. TGen.getRand is handed
        # the same rng, so keep this order to preserve which values each gets.
        updates_const = tf.constant(TGen.getRand(updates_shape, a.dtype, self.rng))

        # Each index component is bounded by the dimension of ``shape`` it
        # addresses, i.e. its position along the last axis of the indices.
        # The tensors are small, so filling element by element is fine.
        indices = np.zeros(self.indices_shape, dtype=np.int32)
        for pos in np.ndindex(indices.shape):
            upper = self.shape[pos[-1]]
            indices[pos] = self.rng.integers(0, upper, size=1)[0]
        indices_const = tf.constant(indices, dtype=tf.int32)

        return tf.scatter_nd(
            indices=indices_const,
            updates=updates_const,
            shape=shape_const,
            name=self.result_name,
        )
1054
class SpaceToBatch:
    """Move spatial blocks of size ``block_shape`` into the batch dimension."""

    def __init__(self, block_shape, padding, name):
        self.block_shape = block_shape
        self.padding = padding
        self.result_name = name

    def eval(self, a):
        blocks, pads = self.block_shape, self.padding
        return tf.space_to_batch(a, blocks, pads, name=self.result_name)
1065
class DynamicSpaceToBatch:
    """Keras model applying space_to_batch to an input with a dynamic batch."""

    def __init__(self, block_shape, padding, dynamic_input_shape, name):
        # Stored for interface parity; the Keras model output is not renamed.
        self.result_name = name

        # dynamic_input_shape includes the batch dimension; drop it before
        # building the Keras Input.
        per_sample_shape = tuple(list(dynamic_input_shape)[1:])

        self.model = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=per_sample_shape),
                tf.keras.layers.Lambda(
                    lambda x: tf.space_to_batch(x, block_shape, padding, name=None)
                ),
            ]
        )

    def eval(self, a):
        return self.model(a)
1085
class BatchToSpace:
    """batch_to_space applied after swapping the batch and depth axes."""

    def __init__(self, block_shape, cropping, name):
        self.block_shape = block_shape
        self.cropping = cropping
        self.result_name = name

    def eval(self, a):
        # Transpose so the last (depth) axis becomes the batch axis and the
        # original batch axis moves to the end, leaving the block axes in
        # place. This avoids having to introduce a new shape.
        block_rank = len(self.block_shape)
        last_axis = len(a.shape) - 1
        perm = [last_axis, *range(1, block_rank + 1), 0]
        swapped = tf.transpose(a, perm)
        return tf.batch_to_space(
            swapped, self.block_shape, self.cropping, name=self.result_name
        )
1103
class DynamicBatchToSpace:
    """Keras model applying batch_to_space to an input with a dynamic batch."""

    def __init__(self, block_shape, cropping, dynamic_input_shape, name):
        # Stored for interface parity; the Keras model output is not renamed.
        self.result_name = name

        # dynamic_input_shape includes the batch dimension; drop it before
        # building the Keras Input.
        per_sample_shape = tuple(list(dynamic_input_shape)[1:])

        self.model = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=per_sample_shape),
                tf.keras.layers.Lambda(
                    lambda x: tf.batch_to_space(x, block_shape, cropping, name=None)
                ),
            ]
        )

    def eval(self, a):
        return self.model(a)
1123
class SpaceToDepth:
    """Rearrange spatial blocks of size ``block_shape`` into channels."""

    def __init__(self, block_shape, name):
        self.block_shape = block_shape
        self.result_name = name

    def eval(self, a):
        block = self.block_shape
        return tf.nn.space_to_depth(a, block, name=self.result_name)
1131
class DynamicSpaceToDepth:
    """Keras model applying space_to_depth to an NHWC input with a dynamic batch.

    Args:
        dynamic_input_shape: input shape including the leading batch
            dimension. The batch entry is dropped before building the Keras
            ``Input``.
        name: stored as ``result_name``; the Keras model output is not renamed.
        block_shape: spatial block size. Defaults to 2, the value that used to
            be hard-coded here. The name matches ``SpaceToDepth``.
    """

    def __init__(self, dynamic_input_shape, name, block_shape=2):
        self.result_name = name
        self.block_shape = block_shape

        per_sample_shape = tuple(list(dynamic_input_shape)[1:])

        self.model = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=per_sample_shape),
                tf.keras.layers.Lambda(
                    lambda x: tf.nn.space_to_depth(
                        x, block_shape, data_format="NHWC", name=None
                    )
                ),
            ]
        )

    def eval(self, a):
        return self.model(a)
1153
class DepthToSpace:
    """Rearrange channels into spatial blocks of size ``block_shape``."""

    def __init__(self, block_shape, name):
        self.block_shape = block_shape
        self.result_name = name

    def eval(self, a):
        block = self.block_shape
        return tf.nn.depth_to_space(a, block, name=self.result_name)
1161
class DynamicDepthToSpace:
    """Keras model applying depth_to_space to an NHWC input with a dynamic batch.

    Args:
        dynamic_input_shape: input shape including the leading batch
            dimension. The batch entry is dropped before building the Keras
            ``Input``.
        name: stored as ``result_name``; the Keras model output is not renamed.
        block_shape: spatial block size. Defaults to 2, the value that used to
            be hard-coded here. The name matches ``DepthToSpace``.
    """

    def __init__(self, dynamic_input_shape, name, block_shape=2):
        self.result_name = name
        self.block_shape = block_shape

        per_sample_shape = tuple(list(dynamic_input_shape)[1:])

        self.model = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=per_sample_shape),
                tf.keras.layers.Lambda(
                    lambda x: tf.nn.depth_to_space(
                        x, block_shape, data_format="NHWC", name=None
                    )
                ),
            ]
        )

    def eval(self, a):
        return self.model(a)
1183
class OneHot:
    """Wraps ``tf.one_hot``; the output dtype is taken from ``on_value``."""

    def __init__(self, depth, axis, name):
        self.result_name = name
        self.axis = axis
        self.depth = depth

    def eval(self, indices, on_value, off_value):
        """One-hot encode ``indices`` using the stored depth and axis."""
        return tf.one_hot(
            indices,
            depth=self.depth,
            on_value=on_value,
            off_value=off_value,
            axis=self.axis,
            dtype=on_value.dtype,
            name=self.result_name,
        )
1200
class Fakequant:
    """Fake-quantize the input over the fixed range [-2.0, 2.0]."""

    def __init__(self, num_bits, narrow_range, name):
        self.result_name = name
        self.narrow_range = narrow_range
        self.num_bits = num_bits

    def eval(self, a):
        """Apply ``fake_quant_with_min_max_args`` with the stored settings."""
        quant_kwargs = {
            "min": -2.0,
            "max": 2.0,
            "num_bits": self.num_bits,
            "narrow_range": self.narrow_range,
            "name": self.result_name,
        }
        return tf.quantization.fake_quant_with_min_max_args(a, **quant_kwargs)
1216
class Resize:
    """Resize input dims 1 and 2 by ``scale`` using a TF v1 image resize op.

    ``mode == "nearest"`` selects nearest-neighbour; any other value selects
    bilinear. ``align`` and ``half`` map to ``align_corners`` and
    ``half_pixel_centers`` respectively.
    """

    def __init__(self, mode, align, half, scale, name):
        self.result_name = name
        self.mode = mode
        self.align = align
        self.half = half
        self.scale = scale

    def eval(self, a):
        """Resize ``a`` and return it under ``result_name`` via ``tf.identity``."""
        # Target size: input dims 1 and 2 multiplied by the scale factor.
        target_size = [a.shape[1] * self.scale, a.shape[2] * self.scale]

        if self.mode == "nearest":
            resize_fn = tf.compat.v1.image.resize_nearest_neighbor
        else:
            resize_fn = tf.compat.v1.image.resize_bilinear

        resized = resize_fn(
            a,
            target_size,
            align_corners=self.align,
            name="resize",
            half_pixel_centers=self.half,
        )
        # The resize op itself is named "resize"; the identity carries the result name.
        return tf.identity(resized, name=self.result_name)
1243
class LeftShift:
    """Bitwise left shift of the input by a constant amount."""

    def __init__(self, shift, name):
        self.result_name = name
        self.shift = shift

    def eval(self, a):
        """Return ``a << shift`` via ``tf.bitwise.left_shift``."""
        amount = self.shift
        return tf.bitwise.left_shift(a, amount, name=self.result_name)
1251
class RightShift:
    """Bitwise right shift of the input by a constant amount."""

    def __init__(self, shift, name):
        self.result_name = name
        self.shift = shift

    def eval(self, a):
        """Return ``a >> shift`` via ``tf.bitwise.right_shift``."""
        amount = self.shift
        return tf.bitwise.right_shift(a, amount, name=self.result_name)
Jerry Ge9e94af82022-10-27 09:57:00 -07001259
class While:
    """``tf.while_loop`` that adds sigmoid(x) to x while sum(x) < 2.0."""

    def __init__(self, name):
        self.result_name = name

    def while_cond(self, x):
        # Keep looping while the total of x is below self.cap (set by eval()).
        return tf.reduce_sum(x) < self.cap

    def while_body(self, x):
        return tf.add(x, tf.math.sigmoid(x))

    def eval(self, a):
        """Run the loop starting from ``a`` and return the final loop value."""
        cap_value = tf.constant(2.0, shape=[1])
        # Cast the cap to the loop variable's dtype so the comparison matches.
        self.cap = tf.cast(cap_value, a.dtype)

        loop_outputs = tf.while_loop(
            cond=self.while_cond,
            body=self.while_body,
            loop_vars=[a],
            name=self.result_name,
        )
        return loop_outputs[0]
1286
class LSTM(tf.Module):
    """Non-stateful 2-unit Keras LSTM with kernel/recurrent weights set to ones."""

    def __init__(self, name):
        # NOTE(review): tf.Module.__init__ is not called here; behaviour kept as-is.
        self.result_name = name
        self.lstm = tf.keras.layers.LSTM(
            units=2,
            kernel_initializer="ones",
            recurrent_initializer="ones",
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            unroll=False,
        )

    def eval(self, a):
        """Run the LSTM layer on ``a``."""
        layer = self.lstm
        return layer(a)
1302
class SLSTM(tf.Module):
    """Stateful 2-unit Keras LSTM with kernel/recurrent weights set to ones."""

    def __init__(self, name):
        # NOTE(review): tf.Module.__init__ is not called here; behaviour kept as-is.
        self.result_name = name
        self.lstm = tf.keras.layers.LSTM(
            units=2,
            stateful=True,
            kernel_initializer="ones",
            recurrent_initializer="ones",
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            unroll=False,
        )

    def eval(self, a):
        """Run the stateful LSTM layer on ``a``."""
        layer = self.lstm
        return layer(a)
1319
class GRU:
    """2-unit Keras GRU with kernel/recurrent weights set to ones."""

    def __init__(self, name):
        self.result_name = name
        # The attribute is named ``lstm`` although it holds a GRU layer;
        # the name is kept unchanged for compatibility.
        self.lstm = tf.keras.layers.GRU(
            units=2,
            kernel_initializer="ones",
            recurrent_initializer="ones",
            recurrent_activation="sigmoid",
            use_bias=True,
        )

    def eval(self, a):
        """Run the GRU layer on ``a``."""
        layer = self.lstm
        return layer(a)
1333
class RNN:
    """Keras RNN layer built from a 2-unit sigmoid ``SimpleRNNCell``."""

    def __init__(self, name):
        self.result_name = name
        cell = tf.keras.layers.SimpleRNNCell(
            units=2,
            recurrent_initializer="ones",
            activation="sigmoid",
            use_bias=True,
        )
        self.rnn = tf.keras.layers.RNN(cell, unroll=False)

    def eval(self, a):
        """Run the RNN layer on ``a``."""
        layer = self.rnn
        return layer(a)
1347
class FullyConnected:
    """Keras ``Dense`` layer with 2 output units."""

    def __init__(self, name):
        self.result_name = name
        self.dense = tf.keras.layers.Dense(units=2)

    def eval(self, a):
        """Apply the dense layer to ``a``."""
        layer = self.dense
        return layer(a)
Luke Hutton261b7b62023-01-10 14:50:31 +00001355
class RFFT2d:
    """Wraps ``tf.signal.rfft2d`` with a fixed FFT length."""

    def __init__(self, fft_length, name):
        self.result_name = name
        self.fft_length = fft_length

    def eval(self, a):
        """Return the 2-D real FFT of ``a``."""
        length = self.fft_length
        return tf.signal.rfft2d(a, length, name=self.result_name)
Luke Hutton714aa602023-02-08 19:45:26 +00001363
class Real:
    """Wraps ``tf.math.real``."""

    def __init__(self, name):
        self.result_name = name

    def eval(self, a):
        """Return the real part of ``a``."""
        op_name = self.result_name
        return tf.math.real(a, name=op_name)
1370
class Imag:
    """Wraps ``tf.math.imag``."""

    def __init__(self, name):
        self.result_name = name

    def eval(self, a):
        """Return the imaginary part of ``a``."""
        op_name = self.result_name
        return tf.math.imag(a, name=op_name)
Tai Lyfe36fa92023-06-01 21:45:12 +00001377
class BroadcastTo:
    """Wraps ``tf.broadcast_to`` with a fixed target shape."""

    def __init__(self, shape, name):
        self.result_name = name
        self.shape = shape

    def eval(self, a):
        """Broadcast ``a`` to the stored shape."""
        target = self.shape
        return tf.broadcast_to(a, shape=target, name=self.result_name)
Tai Lycf84bc92023-09-07 20:49:09 +00001385
1386 class CallOnce(tf.Module):
def __init__(self, name):
    """Store the result name and create the variable that ``eval`` reassigns.

    Args:
        name: Stored as ``result_name``.
    """
    self.result_name = name
    # Initial value [1.0]; eval() assigns [2.0] to it.
    self.var = tf.Variable([1.0])
1391
@tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
def eval(self, a):
    """Assign [2.0] to ``self.var`` and return the assign result.

    ``a`` is not read; the traced signature is a float32 tensor of shape [1].
    """
    new_value = [2.0]
    return self.var.assign(new_value)