blob: 630286549e698169e56ce8795cd68dfbd13d768e [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson015c3552022-02-23 12:15:03 +00002# SPDX-License-Identifier: Apache-2.0
3import numpy as np
4import tensorflow as tf
5from frameworks.tensor_gen import TGen
6
7
class TBuilder:
    """The member functions build the tensorflow operators into small networks
    for our tests.

    Each nested class wraps one TensorFlow operator (or a short chain of them).
    Its constructor captures the operator's static attributes plus a result
    name, which is applied to the output op where the wrapper supports it.
    Its ``eval`` method takes the runtime input tensor(s) and returns the
    resulting tensor.
    """

    def __init__(self):
        pass

    # Neither helper below uses instance state. @staticmethod makes each one
    # callable both as TBuilder.fake_quant(...) and through an instance.
    @staticmethod
    def fake_quant(tensor, tensor_scale, name):
        """Helper function for quantizing with a scaling parameters structure.

        Args:
            tensor: Tensor to fake-quantize.
            tensor_scale: Object exposing ``min``, ``max``, ``num_bits`` and
                ``narrow_range`` attributes.
            name: Name of the resulting op.

        Returns:
            The output of ``tf.quantization.fake_quant_with_min_max_args``.
        """
        return tf.quantization.fake_quant_with_min_max_args(
            tensor,
            min=tensor_scale.min,
            max=tensor_scale.max,
            num_bits=tensor_scale.num_bits,
            narrow_range=tensor_scale.narrow_range,
            name=name,
        )

    @staticmethod
    def fake_quant_params(tensor, min, max, scaling, name):
        """Helper function for quantizing with individual scaling parameters.

        Args:
            tensor: Tensor to fake-quantize.
            min: Lower bound of the quantization range.
            max: Upper bound of the quantization range.
            scaling: Object exposing ``num_bits`` and ``narrow_range``.
            name: Name of the resulting op.

        Returns:
            The output of ``tf.quantization.fake_quant_with_min_max_args``.
        """
        # `min`/`max` shadow the builtins. They are kept because callers may
        # pass them by keyword.
        return tf.quantization.fake_quant_with_min_max_args(
            tensor,
            min=min,
            max=max,
            num_bits=scaling.num_bits,
            narrow_range=scaling.narrow_range,
            name=name,
        )

    class Add:
        """Element-wise ``a + b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.add(a, b, name=self.result_name)

    class Sub:
        """Element-wise ``a - b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.subtract(a, b, name=self.result_name)

    class Mul:
        """Element-wise ``a * b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.multiply(a, b, name=self.result_name)

    class Exp:
        """Element-wise exponential."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.exp(a, name=self.result_name)

    class Rcp:
        """Element-wise reciprocal ``1 / a``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.reciprocal(a, name=self.result_name)

    class Relu:
        """Rectified linear unit."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.relu(a, name=self.result_name)

    class Relu1:
        """Clamp ``a`` to the range [-1, 1]."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            # TF doesn't have a relu_n1_to_1 operator, so use min and max as a
            # workaround (clip_by_value would be an alternative). The outer op
            # carries the result name, like every other wrapper in this class.
            return tf.math.minimum(
                1.0, tf.math.maximum(-1.0, a), name=self.result_name
            )

    class Relu0To1:
        """Clamp ``a`` to the range [0, 1]."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            # TF doesn't have a relu_0_to_1 operator, so use min and max as a
            # workaround (clip_by_value would be an alternative). The outer op
            # carries the result name, like every other wrapper in this class.
            return tf.math.minimum(
                1.0, tf.math.maximum(0.0, a), name=self.result_name
            )

    class Relu6:
        """ReLU capped at 6."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.relu6(a, name=self.result_name)

    class LeakyRelu:
        """Leaky ReLU with negative slope ``alpha``."""

        def __init__(self, alpha, name):
            self.alpha = alpha
            self.result_name = name

        def eval(self, a):
            return tf.nn.leaky_relu(a, alpha=self.alpha, name=self.result_name)

    class Prelu:
        """Keras ``PReLU`` layer whose alpha is drawn from N(0, 1)."""

        def __init__(self, name):
            # `name` is stored but not passed to the Keras layer.
            self.result_name = name
            self.prelu = tf.keras.layers.PReLU(
                alpha_initializer=tf.keras.initializers.RandomNormal(
                    mean=0.0, stddev=1.0
                )
            )

        def eval(self, a):
            return self.prelu(a)

    class Gelu:
        """Gaussian error linear unit."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.gelu(a, name=self.result_name)

    class Concat:
        """Concatenate two tensors along ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b):
            return tf.concat([a, b], self.axis, name=self.result_name)

    class BitwiseAnd:
        """Element-wise bitwise AND."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_and(a, b, name=self.result_name)

    class BitwiseOr:
        """Element-wise bitwise OR."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_or(a, b, name=self.result_name)

    class BitwiseNot:
        """Element-wise bitwise inversion."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.invert(a, name=self.result_name)

    class BitwiseXor:
        """Element-wise bitwise XOR."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.bitwise.bitwise_xor(a, b, name=self.result_name)

    class LogicalAnd:
        """Element-wise logical AND."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.logical_and(a, b, name=self.result_name)

    class LogicalOr:
        """Element-wise logical OR."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.logical_or(a, b, name=self.result_name)

    class LogicalNot:
        """Element-wise logical NOT."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.logical_not(a, name=self.result_name)

    class ReduceAny:
        """Logical OR reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_any(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceAll:
        """Logical AND reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_all(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMin:
        """Minimum reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_min(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMax:
        """Maximum reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_max(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceSum:
        """Sum reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_sum(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceMean:
        """Mean reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_mean(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class ReduceProduct:
        """Product reduction over ``axis_list``."""

        def __init__(self, axis_list, keepdims, name):
            self.axis_list = axis_list
            self.keepdims = keepdims
            self.result_name = name

        def eval(self, a):
            return tf.math.reduce_prod(
                a, self.axis_list, keepdims=self.keepdims, name=self.result_name
            )

    class Min:
        """Element-wise minimum."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.minimum(a, b, name=self.result_name)

    class Max:
        """Element-wise maximum."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.maximum(a, b, name=self.result_name)

    class Pow:
        """Element-wise ``a ** b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.pow(a, b, name=self.result_name)

    class Abs:
        """Element-wise absolute value."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.abs(a, name=self.result_name)

    class Ceil:
        """Element-wise ceiling."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.ceil(a, name=self.result_name)

    class Floor:
        """Element-wise floor."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.floor(a, name=self.result_name)

    class Log:
        """Element-wise natural logarithm."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.log(a, name=self.result_name)

    class Negate:
        """Element-wise negation."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.negative(a, name=self.result_name)

    class Rsqrt:
        """Element-wise reciprocal square root."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.rsqrt(a, name=self.result_name)

    class Sign:
        """Element-wise sign."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sign(a, name=self.result_name)

    class Sigmoid:
        """Element-wise logistic sigmoid."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sigmoid(a, name=self.result_name)

    class Tanh:
        """Element-wise hyperbolic tangent."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.tanh(a, name=self.result_name)

    class Sin:
        """Element-wise sine."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.sin(a, name=self.result_name)

    class Cos:
        """Element-wise cosine."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.cos(a, name=self.result_name)

    class Atan2:
        """Element-wise two-argument arctangent of ``a`` and ``b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.atan2(a, b, name=self.result_name)

    class Square:
        """Element-wise square."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.math.square(a, name=self.result_name)

    class SquaredDifference:
        """Element-wise ``(a - b) ** 2``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.squared_difference(a, b, name=self.result_name)

    class Equal:
        """Element-wise ``a == b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.equal(a, b, name=self.result_name)

    class GreaterEqual:
        """Element-wise ``a >= b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.greater_equal(a, b, name=self.result_name)

    class Greater:
        """Element-wise ``a > b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.greater(a, b, name=self.result_name)

    class Less:
        """Element-wise ``a < b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.less(a, b, name=self.result_name)

    class LessEqual:
        """Element-wise ``a <= b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.math.less_equal(a, b, name=self.result_name)

    class Conv2d:
        """NHWC 2-D convolution with a fixed weight tensor."""

        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name=self.result_name,
            )

    class Conv2dRelu:
        """Unit-stride SAME NHWC conv2d followed by ReLU."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.nn.relu(conv2d, name=self.result_name)

    class Conv2dRelu6:
        """Unit-stride SAME NHWC conv2d followed by ReLU6."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.nn.relu6(conv2d, name=self.result_name)

    class Conv2dReluN1To1:
        """Unit-stride SAME NHWC conv2d followed by clipping to [-1, 1]."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.clip_by_value(conv2d, -1.0, 1.0, name=self.result_name)

    class Conv2dTanh:
        """Unit-stride SAME NHWC conv2d followed by tanh."""

        def __init__(self, weight, name):
            self.weight = weight
            self.result_name = name

        def eval(self, input):
            conv2d = tf.nn.conv2d(
                input,
                self.weight,
                [1, 1, 1, 1],
                "SAME",
                data_format="NHWC",
                dilations=[1, 1, 1, 1],
                name="conv2d",
            )
            return tf.math.tanh(conv2d, name=self.result_name)

    class Conv2dWithBias:
        """NHWC conv2d followed by a bias add."""

        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            conv2d_op = tf.nn.conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="conv2d",
            )
            return tf.nn.bias_add(
                conv2d_op, self.bias, data_format="NHWC", name=self.result_name
            )

    class Conv3d:
        """NDHWC 3-D convolution with a fixed weight tensor."""

        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv3d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NDHWC",
                dilations=self.dilations,
                name=self.result_name,
            )

    class Conv3dWithBias:
        """NDHWC conv3d followed by a bias add."""

        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            conv3d_op = tf.nn.conv3d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NDHWC",
                dilations=self.dilations,
                name="conv3d",
            )
            # bias_add is called without an explicit data_format.
            return tf.nn.bias_add(conv3d_op, self.bias, name=self.result_name)

    class DepthwiseConv2d:
        """NHWC depthwise conv2d, renamed via an identity op."""

        def __init__(self, weight, strides, padding, dilations, name):
            self.weight = weight
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            dws_conv2d = tf.nn.depthwise_conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="dws_conv2d",
            )
            return tf.identity(dws_conv2d, name=self.result_name)

    class DepthwiseConv2dWithBias:
        """NHWC depthwise conv2d followed by a bias add."""

        def __init__(self, weight, bias, strides, padding, dilations, name):
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.padding = padding
            self.dilations = dilations
            self.result_name = name

        def eval(self, input):
            dws_conv2d = tf.nn.depthwise_conv2d(
                input,
                self.weight,
                self.strides,
                self.padding,
                data_format="NHWC",
                dilations=self.dilations,
                name="dws_conv2d",
            )
            return tf.nn.bias_add(
                dws_conv2d, self.bias, data_format="NHWC", name=self.result_name
            )

    class TransposeConv2d:
        """NHWC transposed 2-D convolution producing ``output_shape``."""

        def __init__(self, weight, output_shape, strides, padding, name):
            self.weight = weight
            self.output_shape = output_shape
            self.strides = strides
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.conv2d_transpose(
                input,
                self.weight,
                self.output_shape,
                self.strides,
                self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class Argmax:
        """Index of the maximum along ``axis``, as int32."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.argmax(a, self.axis, output_type=tf.int32, name=self.result_name)

    class AvgPool2d:
        """NHWC 2-D average pooling."""

        def __init__(self, strides, kernel_size, padding, name):
            self.strides = strides
            self.kernel_size = kernel_size
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.avg_pool2d(
                input,
                strides=self.strides,
                ksize=self.kernel_size,
                padding=self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class MaxPool2d:
        """NHWC 2-D max pooling."""

        def __init__(self, strides, kernel_size, padding, name):
            self.strides = strides
            self.kernel_size = kernel_size
            self.padding = padding
            self.result_name = name

        def eval(self, input):
            return tf.nn.max_pool2d(
                input,
                strides=self.strides,
                ksize=self.kernel_size,
                padding=self.padding,
                data_format="NHWC",
                name=self.result_name,
            )

    class Reshape:
        """Reshape to ``shape``; the result is named via an identity op."""

        def __init__(self, shape, name):
            self.shape = shape
            self.result_name = name

        def eval(self, a):
            reshape_op = tf.reshape(a, self.shape)
            return tf.identity(reshape_op, name=self.result_name)

    class Transpose:
        """Permute dimensions according to ``perm``."""

        def __init__(self, perm, name):
            self.perm = perm
            self.result_name = name

        def eval(self, a):
            return tf.transpose(a, self.perm, name=self.result_name)

    class Slice:
        """Extract a slice of ``size`` starting at ``begin``."""

        def __init__(self, begin, size, name):
            self.begin = begin
            self.size = size
            self.result_name = name

        def eval(self, a):
            return tf.slice(a, begin=self.begin, size=self.size, name=self.result_name)

    class StridedSlice:
        """``tf.strided_slice`` with all mask arguments exposed."""

        def __init__(
            self,
            begin,
            end,
            strides,
            begin_mask,
            end_mask,
            ellipsis_mask,
            new_axis_mask,
            shrink_axis_mask,
            name,
        ):
            self.begin = begin
            self.end = end
            self.strides = strides
            self.begin_mask = begin_mask
            self.end_mask = end_mask
            self.ellipsis_mask = ellipsis_mask
            self.new_axis_mask = new_axis_mask
            self.shrink_axis_mask = shrink_axis_mask
            self.result_name = name

        def eval(self, a):
            return tf.strided_slice(
                a,
                begin=self.begin,
                end=self.end,
                strides=self.strides,
                begin_mask=self.begin_mask,
                end_mask=self.end_mask,
                ellipsis_mask=self.ellipsis_mask,
                new_axis_mask=self.new_axis_mask,
                shrink_axis_mask=self.shrink_axis_mask,
                name=self.result_name,
            )

    class Select:
        """Pick elements from ``a`` where ``selector`` is true, else from ``b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, selector, a, b):
            return tf.where(condition=selector, x=a, y=b, name=self.result_name)

    class Addn:
        """Sum of four tensors."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.add_n([a, b, c, d], name=self.result_name)

    class Concatv2:
        """Concatenate four tensors along ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)

    class Stack:
        """Stack four tensors along a new ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a, b, c, d):
            return tf.stack([a, b, c, d], axis=self.axis, name=self.result_name)

    class Unstack:
        """Unstack along ``axis``, then stack each slice's scalar sum.

        When the axis has length 1, the single slice is returned unreduced.
        """

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            unstack_op = tf.unstack(a, axis=self.axis, name="unstack_op")

            if a.shape[self.axis] == 1:
                return tf.identity(unstack_op[0], name=self.result_name)

            # Reduce each slice to one value so the op has a single output tensor.
            sums = [
                tf.math.reduce_sum(piece, name=f"reduce_{i}")
                for i, piece in enumerate(unstack_op)
            ]
            return tf.stack(sums, 0, name=self.result_name)

    class MirrorPad:
        """``tf.pad`` using the given ``mode`` and padding spec."""

        def __init__(self, padding, mode, name):
            self.padding = padding
            self.mode = mode
            self.result_name = name

        def eval(self, a):
            return tf.pad(
                a,
                self.padding,
                mode=self.mode,
                constant_values=0,
                name=self.result_name,
            )

    class Pad:
        """Constant-mode ``tf.pad`` filling with ``pad_const``."""

        def __init__(self, padding, pad_const, name):
            self.padding = padding
            self.pad_const = pad_const
            self.result_name = name

        def eval(self, a):
            return tf.pad(
                a,
                self.padding,
                mode="CONSTANT",
                constant_values=self.pad_const,
                name=self.result_name,
            )

    class ExpandDims:
        """Insert a length-1 dimension at ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.expand_dims(a, self.axis, name=self.result_name)

    class Shape:
        """Shape of ``a`` as a tensor."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.shape(a, name=self.result_name)

    class Rank:
        """Rank of ``a`` as a tensor."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.rank(a, name=self.result_name)

    class Fill:
        """Tensor of ``shape`` filled with ``value``.

        The ``a`` argument to ``eval`` is accepted but not used.
        """

        def __init__(self, shape, value, name):
            self.shape = shape
            self.value = value
            self.result_name = name

        def eval(self, a):
            return tf.fill(self.shape, self.value, name=self.result_name)

    class Elu:
        """Exponential linear unit."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.elu(a, name=self.result_name)

    class Softmax:
        """Softmax using ``tf.nn.softmax`` defaults."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.softmax(a, name=self.result_name)

    class LogSoftmax:
        """Log-softmax using ``tf.nn.log_softmax`` defaults."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.nn.log_softmax(a, name=self.result_name)

    class MatMul:
        """Matrix product ``a @ b``."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            return tf.linalg.matmul(a, b, name=self.result_name)

    class AddScalar:
        """Add the scalar 1 to every element."""

        def __init__(self, name):
            self.result_name = name

        def eval(self, a):
            return tf.add(a, 1, name=self.result_name)

    class Add1d:
        """Add ``b`` to ``a`` after collapsing ``b`` to its last dimension.

        If ``b`` has rank > 1, it is summed over every axis except the last.
        """

        def __init__(self, name):
            self.result_name = name

        def eval(self, a, b):
            if len(b.shape) > 1:
                b_1d = tf.reduce_sum(b, axis=list(range(len(b.shape) - 1)))
            else:
                b_1d = b
            return tf.add(a, b_1d, name=self.result_name)

    class Split:
        """Split ``a`` along ``axis``, then stack each piece's scalar sum.

        ``num_splits`` is either an int (the number of equal pieces) or a
        list of explicit piece sizes.
        """

        def __init__(self, num_splits, axis, name):
            self.num_splits = num_splits
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            # The split op generates a list of outputs. We have difficulty
            # serializing a list or array of Numpy arrays, so each piece is
            # reduced to its sum and the sums are stacked into a single tensor.
            if isinstance(self.num_splits, list):
                # Explicit piece sizes are fed to tf.split as an int32 constant.
                num_split = np.asarray(self.num_splits, dtype=np.int32)
                split_vec_op = tf.compat.v1.constant(
                    num_split,
                    shape=num_split.shape,
                    dtype=tf.int32,
                    name="const_split_vec",
                )
                split_op = tf.split(
                    a, num_or_size_splits=split_vec_op, axis=self.axis, name="split"
                )
            else:
                split_op = tf.split(
                    a, num_or_size_splits=self.num_splits, axis=self.axis, name="split"
                )

            sums = [
                tf.math.reduce_sum(piece, name=f"reduce_{i}")
                for i, piece in enumerate(split_op)
            ]
            return tf.stack(sums, 0, name=self.result_name)

    class Tile:
        """Tile ``a`` by ``multiples``; the result is named via an identity op."""

        def __init__(self, multiples, name):
            self.multiples = multiples
            self.result_name = name

        def eval(self, a):
            t = tf.tile(a, self.multiples, name="tile")
            return tf.identity(t, name=self.result_name)

    class Reverse:
        """Reverse ``a`` along the single dimension ``axis``."""

        def __init__(self, axis, name):
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.reverse(a, [self.axis], name=self.result_name)

    class Gather:
        """``tf.gather`` with fixed ``indices``, ``batch_dims`` and ``axis``."""

        def __init__(self, indices, batch_dims, axis, name):
            self.indices = indices
            self.batch_dims = batch_dims
            self.axis = axis
            self.result_name = name

        def eval(self, a):
            return tf.gather(
                a,
                self.indices,
                batch_dims=self.batch_dims,
                axis=self.axis,
                name=self.result_name,
            )

    class GatherNd:
        """``tf.gather_nd`` with fixed ``indices``."""

        def __init__(self, indices, name):
            self.indices = indices
            self.result_name = name

        def eval(self, a):
            return tf.gather_nd(a, self.indices, name=self.result_name)

    class ScatterNd:
        """``tf.scatter_nd`` into ``shape`` with randomly generated inputs.

        Both the indices and the updates are generated from ``rng`` at eval
        time.
        """

        def __init__(self, shape, indices_shape, N, rng, name):
            self.shape = shape
            self.indices_shape = indices_shape
            self.N = N
            self.rng = rng
            self.result_name = name

        def eval(self, a):
            # This operator is special. The indices and updates tensors really
            # need to be created together, but the current structure of this
            # tool gives no way to do that before now. The number of updates is
            # determined by the indices, so updates can only be created after
            # the indices, yet the dtype is not known until this point.
            #
            # Shapes are guaranteed deterministic. Values are drawn from the rng
            # copied from the arggen stage, so index and update *values* may be
            # non-deterministic.
            #
            # The input tensor `a` is used only for its dtype.
            shape_const = tf.constant(self.shape, tf.int32)

            # Each update covers the trailing dimensions of `shape` that the
            # last axis of the indices does not address.
            updates_shape = list(self.indices_shape[:-1])
            updates_shape.extend(self.shape[self.indices_shape[-1] :])

            updates_const = tf.constant(TGen.getRand(updates_shape, a.dtype, self.rng))

            indices = np.zeros(self.indices_shape, dtype=np.int32)

            # Bound each index element by the size of the dimension it
            # addresses (its position along the indices' last axis). Elements
            # are drawn one at a time, in order, which keeps the rng
            # consumption sequence fixed. The tensors are small, so vectorizing
            # is not worth it.
            for idx in np.ndindex(indices.shape):
                indices[idx] = self.rng.integers(0, self.shape[idx[-1]], size=1)[0]

            indices_const = tf.constant(indices, dtype=tf.int32)

            return tf.scatter_nd(
                indices=indices_const,
                updates=updates_const,
                shape=shape_const,
                name=self.result_name,
            )

    class SpaceToBatch:
        """``tf.space_to_batch`` with fixed ``block_shape`` and ``padding``."""

        def __init__(self, block_shape, padding, name):
            self.block_shape = block_shape
            self.padding = padding
            self.result_name = name

        def eval(self, a):
            return tf.space_to_batch(
                a, self.block_shape, self.padding, name=self.result_name
            )

    class BatchToSpace:
        """Transpose the last and first dims of ``a``, then ``tf.batch_to_space``."""

        def __init__(self, block_shape, cropping, name):
            self.block_shape = block_shape
            self.cropping = cropping
            self.result_name = name

        def eval(self, a):
            # Swap the last (depth) and first (batch) dimensions first. This
            # avoids having to invent a new input shape. The block dimensions
            # 1..block_rank keep their positions.
            block_rank = len(self.block_shape)
            perm = [len(a.shape) - 1, *range(1, block_rank + 1), 0]
            transpose_op = tf.transpose(a, perm)
            return tf.batch_to_space(
                transpose_op, self.block_shape, self.cropping, name=self.result_name
            )

    class SpaceToDepth:
        """``tf.nn.space_to_depth`` with block size ``block_shape``."""

        def __init__(self, block_shape, name):
            self.block_shape = block_shape
            self.result_name = name

        def eval(self, a):
            return tf.nn.space_to_depth(a, self.block_shape, name=self.result_name)

    class DepthToSpace:
        """``tf.nn.depth_to_space`` with block size ``block_shape``."""

        def __init__(self, block_shape, name):
            self.block_shape = block_shape
            self.result_name = name

        def eval(self, a):
            return tf.nn.depth_to_space(a, self.block_shape, name=self.result_name)

    class OneHot:
        """One-hot encode ``indices``; the output dtype follows ``on_value``."""

        def __init__(self, depth, axis, name):
            self.depth = depth
            self.axis = axis
            self.result_name = name

        def eval(self, indices, on_value, off_value):
            return tf.one_hot(
                indices,
                self.depth,
                on_value=on_value,
                off_value=off_value,
                axis=self.axis,
                dtype=on_value.dtype,
                name=self.result_name,
            )

    class Fakequant:
        """Fake-quantize ``a`` over the fixed range [-2.0, 2.0]."""

        def __init__(self, num_bits, narrow_range, name):
            self.num_bits = num_bits
            self.narrow_range = narrow_range
            self.result_name = name

        def eval(self, a):
            return tf.quantization.fake_quant_with_min_max_args(
                a,
                min=-2.0,
                max=2.0,
                num_bits=self.num_bits,
                narrow_range=self.narrow_range,
                name=self.result_name,
            )

    class Resize:
        """Scale dims 1 and 2 of ``a`` by ``scale`` (nearest or bilinear).

        ``mode == "nearest"`` selects nearest-neighbour resizing. Any other
        value selects bilinear. ``align`` and ``half`` map to
        ``align_corners`` and ``half_pixel_centers``.
        """

        def __init__(self, mode, align, half, scale, name):
            self.result_name = name
            self.mode = mode
            self.align = align
            self.half = half
            self.scale = scale

        def eval(self, a):
            out_shape = [a.shape[1] * self.scale, a.shape[2] * self.scale]

            resize_fn = (
                tf.compat.v1.image.resize_nearest_neighbor
                if self.mode == "nearest"
                else tf.compat.v1.image.resize_bilinear
            )
            resize = resize_fn(
                a,
                out_shape,
                align_corners=self.align,
                name="resize",
                half_pixel_centers=self.half,
            )
            return tf.identity(resize, name=self.result_name)

    class LeftShift:
        """Bitwise left shift of ``a`` by ``shift``."""

        def __init__(self, shift, name):
            self.shift = shift
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.left_shift(a, self.shift, name=self.result_name)

    class RightShift:
        """Bitwise right shift of ``a`` by ``shift``."""

        def __init__(self, shift, name):
            self.shift = shift
            self.result_name = name

        def eval(self, a):
            return tf.bitwise.right_shift(a, self.shift, name=self.result_name)

    class While:
        """Loop ``x = x + sigmoid(x)`` while ``reduce_sum(x) < 2.0``."""

        def __init__(self, name):
            self.result_name = name

        def while_cond(self, x):
            # Relies on `self.cap`, which eval() sets before building the loop.
            return tf.reduce_sum(x) < self.cap

        def while_body(self, x):
            return tf.add(x, tf.math.sigmoid(x))

        def eval(self, a):
            # The loop bound is cast to the input's dtype so the comparison in
            # while_cond is between matching types.
            self.cap = tf.cast(tf.constant(2.0, shape=[1]), a.dtype)

            result = tf.while_loop(
                self.while_cond, self.while_body, [a], name=self.result_name
            )

            return result[0]

    class LSTM:
        """Keras LSTM layer (2 units) with all-ones kernels."""

        def __init__(self, name):
            # `name` is stored but not passed to the Keras layer.
            self.result_name = name
            self.lstm = tf.keras.layers.LSTM(
                2,
                activation="tanh",
                unroll=False,
                recurrent_activation="sigmoid",
                use_bias=True,
                recurrent_initializer="ones",
                kernel_initializer="ones",
            )

        def eval(self, a):
            return self.lstm(a)

    class GRU:
        """Keras GRU layer (2 units) with all-ones kernels."""

        def __init__(self, name):
            # `name` is stored but not passed to the Keras layer.
            self.result_name = name
            self.gru = tf.keras.layers.GRU(
                2,
                recurrent_activation="sigmoid",
                use_bias=True,
                recurrent_initializer="ones",
                kernel_initializer="ones",
            )
            # Alias kept under the original attribute name in case callers read it.
            self.lstm = self.gru

        def eval(self, a):
            return self.gru(a)

    class RNN:
        """Keras RNN over a 2-unit sigmoid SimpleRNNCell."""

        def __init__(self, name):
            # `name` is stored but not passed to the Keras layer.
            self.result_name = name
            basic_cell = tf.keras.layers.SimpleRNNCell(
                units=2,
                activation="sigmoid",
                use_bias=True,
                recurrent_initializer="ones",
            )
            self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)

        def eval(self, a):
            return self.rnn(a)

    class FullyConnected:
        """Keras Dense layer with 2 output units."""

        def __init__(self, name):
            # `name` is stored but not passed to the Keras layer.
            self.result_name = name
            self.dense = tf.keras.layers.Dense(2)

        def eval(self, a):
            return self.dense(a)

    class RFFT2d:
        """2-D real FFT of ``a`` with the given ``fft_length``."""

        def __init__(self, fft_length, name):
            self.fft_length = fft_length
            self.result_name = name

        def eval(self, a):
            return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
1227 return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)