#!/usr/bin/env python3
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import re
import traceback

import numpy as np

# Level | Level for Humans | Level Description
# -------|------------------|------------------------------------
# 0     | DEBUG            | [Default] Print all messages
# 1     | INFO             | Filter out INFO messages
# 2     | WARNING          | Filter out INFO & WARNING messages
# 3     | ERROR            | Filter out all messages
# Filter tensorflow debug messages except errors
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Flake8 E402 - ignore imports not at top of file to allow os.environ setting
import tensorflow as tf  # noqa: E402
from frameworks.write_test_json import write_test_json  # noqa: E402
from frameworks.arg_gen import ArgGen  # noqa: E402
from frameworks.tensor_gen import TGen  # noqa: E402
from frameworks.test_builder import TBuilder  # noqa: E402
from frameworks.test_gen_utils import (  # noqa: E402
    QuantType,
    get_tf_dtype,
    get_shape_str,
)  # noqa: E402
from tensorflow.lite.python.interpreter import OpResolverType  # noqa: E402

# All of the supported frameworks
ALL_FRAMEWORKS = ["tf", "tflite"]

# Lists of different data types
TYPE_F = [tf.float32]
TYPE_I = [tf.int32]
TYPE_FI = [tf.float32, tf.int32]
TYPE_B = [tf.bool]
TYPE_FIB = [tf.float32, tf.int32, tf.bool]
TYPE_H = [tf.float16]
TYPE_FH = [tf.float32, tf.float16]
TYPE_FHI = [tf.float32, tf.float16, tf.int32]
TYPE_FHIB = [tf.float32, tf.float16, tf.int32, tf.bool]

# The list of operator tests
# Each dictionary entry for an op is a dictionary with the following required members:
#   'operands': tuple (number_of_placeholder_tensors, number_of_constant_tensors)
#   'build_fcn': tuple (Test builder function, Tensor generator function,
#                       Argument generator function)
#   'types': list of Tensorflow types that should be tested for this op
#               OR
#            a dictionary of {'framework_name': [type_list]} for cases where only
#            a subset of the types should be tested in each framework.  This can also
#            be used to restrict an operator to a particular framework.
#
# And optional members:
#   'template': boolean (indicates that this is a templated op which gets further
#               processing in createDynamicOpLists)
#   'bias': boolean indicating that there is a bias component to be generated
#   'qtypes': List of QuantType quantized types to generate for this op

TF_OP_LIST = {
    "add": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Add, TGen.tgBFuzz, ArgGen.agNone),
        "types": {
            "tf": TYPE_FI,
            "tflite": list(
                TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "sub": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Sub, TGen.tgBFuzz, ArgGen.agNone),
        "types": {
            "tf": TYPE_FI,
            "tflite": list(TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8]),
            # QuantType.ALL_I16 fails in TFLite conversion
        },
    },
    "mul": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Mul, TGen.tgBFuzz, ArgGen.agNone),
        "types": {
            "tf": TYPE_FI,
            "tflite": list(
                TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "exp": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Exp, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "rcp": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Rcp, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "relu": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Relu, TGen.tgBasic, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "relu6": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Relu6, TGen.tgBasic, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "leaky_relu": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.LeakyRelu, TGen.tgBasic, ArgGen.agFloat),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "concat": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Concat, TGen.tgBasic, ArgGen.agAxes),
        "types": TYPE_FI,
    },
    "bitwise_and": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.BitwiseAnd, TGen.tgBFuzz, ArgGen.agNone),
        "types": {"tf": TYPE_I},  # Not supported in TF Lite
    },
    "bitwise_or": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.BitwiseOr, TGen.tgBFuzz, ArgGen.agNone),
        "types": {"tf": TYPE_I},  # Not supported in TF Lite
    },
    "bitwise_not": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.BitwiseNot, TGen.tgBFuzz, ArgGen.agNone),
        "types": {"tf": TYPE_I},  # Not supported in TF Lite
    },
    "bitwise_xor": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.BitwiseXor, TGen.tgBFuzz, ArgGen.agNone),
        "types": {"tf": TYPE_I},  # Not supported in TF Lite
    },
    "logical_and": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.LogicalAnd, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_B,
    },
    "logical_or": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.LogicalOr, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_B,
    },
    "logical_not": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.LogicalNot, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_B,
    },
    "reduce_any": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceAny, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": TYPE_B,
    },
    "reduce_all": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceAll, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": {"tf": TYPE_B},
    },
    "reduce_min": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceMin, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": {
            "tf": TYPE_FI,
            "tflite": list(TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8]),
        },
    },
    "reduce_max": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceMax, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": {
            "tf": TYPE_FI,
            "tflite": list(TYPE_FI + [QuantType.ALL_U8, QuantType.ALL_I8]),
        },
    },
    "reduce_sum": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceSum, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": {
            "tf": TYPE_F,
            # v2 converter doesn't recognize quantized reduce_sum
            # "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]),
            "tflite": TYPE_F,
        },
    },
    "reduce_mean": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceMean, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "reduce_product": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ReduceProduct, TGen.tgBasic, ArgGen.agAxesListKeepdims),
        "types": TYPE_F,
    },
    "min": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Min, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "max": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Max, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "pow": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Pow, TGen.tgBFuzz, ArgGen.agNone),
        # Technically, integer is supported, but only for positive exponents.
        # Needs a random argument generator.
        "types": TYPE_F,
    },
    "abs": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Abs, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "ceil": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Ceil, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "floor": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Floor, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "log": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Log, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "negate": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Negate, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "rsqrt": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "sigmoid": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Sigmoid, TGen.tgBasic, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "tanh": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Tanh, TGen.tgBasic, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "square": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Square, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "squared_difference": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.SquaredDifference, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_F,
    },
    "equal": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Equal, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "greater_equal": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.GreaterEqual, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "greater": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Greater, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "less": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Less, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "less_equal": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.LessEqual, TGen.tgBFuzz, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "conv2d_TEMPLATE": {
        "operands": (1, 1),
        "build_fcn": (TBuilder.Conv2d, TGen.tgConv2d, ArgGen.agConv2d),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    "conv2d_relu_TEMPLATE": {
        "operands": (1, 2),
        "build_fcn": (TBuilder.Conv2dRelu, TGen.tgConv2d, ArgGen.agNone),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    "conv2d_relu6_TEMPLATE": {
        "operands": (1, 2),
        "build_fcn": (TBuilder.Conv2dRelu6, TGen.tgConv2d, ArgGen.agNone),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    "conv2d_relu_n1_to_1_TEMPLATE": {
        "operands": (1, 2),
        "build_fcn": (TBuilder.Conv2dReluN1To1, TGen.tgConv2d, ArgGen.agNone),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    # This test is converted as:
    #   tfl.conv2d(){fused_activation_function="NONE"} + tfl.tanh()
    # TODO: any way to generate tfl.conv2d(){fused_activation_function="TANH"}?
    "conv2d_tanh_TEMPLATE": {
        "operands": (1, 2),
        "build_fcn": (TBuilder.Conv2dTanh, TGen.tgConv2d, ArgGen.agNone),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    "conv2d_bias_TEMPLATE": {
        "operands": (1, 2),
        "build_fcn": (TBuilder.Conv2dWithBias, TGen.tgConv2d, ArgGen.agConv2d),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "bias": True,
        "template": True,
    },
    "depthwise_conv2d_TEMPLATE": {
        "operands": (1, 1),
        "build_fcn": (
            TBuilder.DepthwiseConv2d,
            TGen.tgDepthwiseConv2d,
            ArgGen.agDepthwiseConv2d,
        ),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    "depthwise_conv2d_bias_TEMPLATE": {
        "operands": (1, 2),
        "build_fcn": (
            TBuilder.DepthwiseConv2dWithBias,
            TGen.tgDepthwiseConv2d,
            ArgGen.agDepthwiseConv2d,
        ),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "bias": True,
        "template": True,
    },
    "transpose_conv2d_TEMPLATE": {
        "operands": (1, 1),
        "build_fcn": (
            TBuilder.TransposeConv2d,
            TGen.tgTransposeConv2d,
            ArgGen.agTransposeConv2d,
        ),
        "types": {
            "tf": [tf.float32],
            "tflite": [
                tf.float32,
                QuantType.CONV_U8_U8,
                QuantType.CONV_I8_I8,
                QuantType.CONV_I16_I8,
            ],
        },
        "template": True,
    },
    "argmax": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Argmax, TGen.tgBasic, ArgGen.agAxes),
        "types": {"tf": TYPE_F},
    },
    "avg_pool2d": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.AvgPool2d, TGen.tgPooling, ArgGen.agPooling),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "max_pool2d": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.MaxPool2d, TGen.tgPooling, ArgGen.agPooling),
        "types": {
            "tf": TYPE_F,
            "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]),
            # ALL_I16 not supported yet
            # In tensorflow/compiler/mlir/lite/ir/tfl_ops.td,
            # QI16 is missing from MaxPoolOperandAndResultConstraints
            # If QI16 is added back, this test can run through.
        },
    },
    "reshape": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Reshape, TGen.tgBasic, ArgGen.agReshape),
        "types": TYPE_FI,
    },
    "transpose": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Transpose, TGen.tgBasic, ArgGen.agTranspose),
        "types": TYPE_FI,
    },
    "slice": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Slice, TGen.tgBasic, ArgGen.agSlice),
        "types": TYPE_FI,
    },
    "strided_slice": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.StridedSlice, TGen.tgBasic, ArgGen.agStridedSlice),
        "types": TYPE_FI,
    },
    "select": {
        "operands": (3, 0),
        "build_fcn": (TBuilder.Select, TGen.tgSelect, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "addn": {
        "operands": (4, 0),
        "build_fcn": (TBuilder.Addn, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "concatv2": {
        "operands": (4, 0),
        "build_fcn": (TBuilder.Concatv2, TGen.tgBasic, ArgGen.agAxes),
        "types": TYPE_FI,
    },
    "stack": {
        "operands": (4, 0),
        "build_fcn": (TBuilder.Stack, TGen.tgBasic, ArgGen.agStack),
        "types": TYPE_FI,
    },
    "unstack": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Unstack, TGen.tgPooling, ArgGen.agAxes),
        "types": TYPE_F,
    },
    "pad": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad),
        "types": TYPE_F,
    },
    "expand_dims": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ExpandDims, TGen.tgBasic, ArgGen.agStack),
        "types": TYPE_FI,
    },
    "shape": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Shape, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "rank": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Rank, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_FI,
    },
    "fill": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Fill, TGen.tgBasic, ArgGen.agFill),
        "types": TYPE_FI,
    },
    "elu": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Elu, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "softmax": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Softmax, TGen.tgBasic, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "log_softmax": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.LogSoftmax, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "matmul": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.MatMul, TGen.tgMatmul, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F
                + [QuantType.ALL_U8, QuantType.ALL_I8]
                # 16-bit matmul fails to convert
            ),
        },
    },
    "add_scalar": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.AddScalar, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "add_1d": {
        "operands": (2, 0),
        "build_fcn": (TBuilder.Add1d, TGen.tgBasic, ArgGen.agNone),
        "types": TYPE_F,
    },
    "split": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Split, TGen.tgBasic, ArgGen.agSplit),
        "types": TYPE_FI,
    },
    "tile": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Tile, TGen.tgBasic, ArgGen.agTile),
        "types": TYPE_FI,
    },
    "reverse": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Reverse, TGen.tgBasic, ArgGen.agAxes),
        "types": {"tf": TYPE_FI},
    },
    "gather": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.Gather, TGen.tgBasic, ArgGen.agGather),
        "types": TYPE_FI,
    },
    "gather_nd": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.GatherNd, TGen.tgBasic, ArgGen.agGatherND),
        "types": TYPE_FI,
    },
    "scatter_nd": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ScatterNd, TGen.tgBasic, ArgGen.agScatterND),
        "types": TYPE_FI,
    },
    "space_to_batch": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.SpaceToBatch, TGen.tgBasic, ArgGen.agSpaceToBatch),
        "types": TYPE_F,
    },
    "batch_to_space": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.BatchToSpace, TGen.tgBasic, ArgGen.agBatchToSpace),
        "types": TYPE_F,
    },
    "space_to_depth": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.SpaceToDepth, TGen.tgBasic, ArgGen.agSpaceToDepth),
        "types": TYPE_F,
    },
    "depth_to_space": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.DepthToSpace, TGen.tgBasic, ArgGen.agDepthToSpace),
        "types": TYPE_F,
    },
    "one_hot": {
        "operands": (3, 1),
        "build_fcn": (TBuilder.OneHot, TGen.tgOneHot, ArgGen.agOneHot),
        "types": TYPE_FI,
    },
    "fakequant": {
        "operands": (1, 0),
        "build_fcn": (
            TBuilder.Fakequant,
            TGen.tgBasic,
            ArgGen.agFakequant,
        ),
        "types": {"tf": TYPE_F},
    },
    "resize_nearest": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ResizeNearest, TGen.tgPooling, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "resize_bilinear": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.ResizeBilinear, TGen.tgPooling, ArgGen.agNone),
        "types": {
            "tf": TYPE_F,
            "tflite": list(
                TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
            ),
        },
    },
    "left_shift": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.LeftShift, TGen.tgBasic, ArgGen.agShift),
        "types": {"tf": [tf.int32]},
    },
    "right_shift": {
        "operands": (1, 0),
        "build_fcn": (TBuilder.RightShift, TGen.tgBasic, ArgGen.agShift),
        "types": {
            "tf": [
                tf.int32,
            ]
        },
    },
}

# Shapes to be tested; default can be overwritten
shape_list = [
    (1,),
    (64,),
    (14, 19),
    (13, 21, 3),
    (1, 4, 4, 4),
    (1, 8, 4, 17),
    (1, 4, 8, 19),
    (1, 32, 32, 8),
    (1, 7, 7, 9),
]


def gen_rand_shapes(args):
    """Overwrite the global shape list with a new list of random shapes"""
    global shape_list

    rng = np.random.default_rng(args.random_seed)

    # Don't let things get too big... cap the maximum volume, but let
    # an individual dimension be 1..47
    max_total_volume = 32 * 32 * 4

    shape_list = []
    # Only iterate over ranks 2, 3, and 4
    for rank in range(2, 5):
        for n in range(args.random_shapes):
            new_shape = rng.integers(1, 48, size=rank)

            # Set the batch dimension on 4D objects to 1
            if rank == 4:
                new_shape[0] = 1

            # Limit the total shape volume and throw out any
            # shapes that wouldn't leave at least size=2 in some non-batch dimension
            volume = 1
            skip_shape = False
            for i in range(rank):

                volume *= new_shape[i]

                # Reduce the shape, while it's larger than the maximum volume
                while volume > max_total_volume:
                    new_shape[i] = new_shape[i] // 2
                    volume = volume // 2

                    # Now an untenable dimension size? Skip this one.
                    if new_shape[i] < 1:
                        skip_shape = True

            if not skip_shape:
                shape_list.append(tuple(new_shape))


# Construct, run and save a whole tensorflow tf.function to a protobuf file,
# or convert it to a .tflite flatbuffer if it's a quantized unit test
def run_unit_test(
    op_name,
    args,
    test_dir,
    curr_shape,
    addl_args,
    dtype,
    excluded_framework_list,
    quantized_inference_dtype,
    result_name,
    seed,
):

    try:
        op = TF_OP_LIST[op_name]
        op_fcn, tensor_gen_fcn, arg_gen_fcn = op["build_fcn"]

        # Get and seed a random number generator for this test
        rng = np.random.default_rng(seed)

        # return placeholders=(str: name, np.array: value)
        # consts=(str: name, np.array: value)
        placeholders, consts = tensor_gen_fcn(op, curr_shape, dtype, rng)

        # If the test doesn't have any placeholders or consts, terminate early
        if len(placeholders) == 0 and len(consts) == 0:
            return True

        if not args.quiet:
            print(" {} ".format(test_dir))

        try:
            os.mkdir(test_dir)
        except FileExistsError:
            pass

        const_nodes = [value for name, value in consts]

        num_placeholders = len(placeholders)
        # If the test is quantized, create tensor quantization metadata info for
        # each input tensor, based on the quantized type
        if quantized_inference_dtype:
            is_quantized = True
            # TODO: support INT8 IFM x INT4 weight later
            if quantized_inference_dtype == QuantType.ALL_U8:
                qzero = [128] * num_placeholders
                numpy_dtype = [np.uint8] * num_placeholders
                tflite_inference_dtype = tf.uint8
            elif quantized_inference_dtype == QuantType.ALL_I8:
                qzero = [0] * num_placeholders
                numpy_dtype = [np.int8] * num_placeholders
                tflite_inference_dtype = tf.int8
            elif quantized_inference_dtype == QuantType.ALL_I16:
                qzero = [0] * num_placeholders
                numpy_dtype = [np.int16] * num_placeholders
                tflite_inference_dtype = tf.int16
            elif quantized_inference_dtype == QuantType.CONV_U8_U8:
                assert (
                    num_placeholders == 1
                ), "Unsupported number of placeholders for Convolution: {}".format(
                    num_placeholders
                )
                qzero = [128] * num_placeholders
                if num_placeholders == 2:
                    numpy_dtype = [np.uint8, np.uint8]
                else:
                    numpy_dtype = [np.uint8, np.uint8, np.int32]
                tflite_inference_dtype = tf.uint8
            elif quantized_inference_dtype == QuantType.CONV_I8_I8:
                assert (
                    num_placeholders == 1
                ), "Unsupported number of placeholders for Convolution: {}".format(
                    num_placeholders
                )
                qzero = [0] * num_placeholders
                if num_placeholders == 2:
                    numpy_dtype = [np.int8, np.int8]
                else:
                    numpy_dtype = [np.int8, np.int8, np.int32]
                tflite_inference_dtype = tf.int8
            elif quantized_inference_dtype == QuantType.CONV_I16_I8:
                assert (
                    num_placeholders == 1
                ), "Unsupported number of placeholders for Convolution: {}".format(
                    num_placeholders
                )
                if num_placeholders == 2:
                    qzero = [0, 0]
                    numpy_dtype = [np.int16, np.int8]
                else:
                    qzero = [0, 0, 0]
                    numpy_dtype = [
                        np.int16,
                        np.int8,
                        np.int64,
                    ]  # np.int64 to represent the 40-bit accumulator
                tflite_inference_dtype = tf.int16
            else:
                raise Exception(
                    "Unsupported fakequant dtype: {}".format(quantized_inference_dtype)
                )

        else:
            is_quantized = False

        tf_model_filename = None
        tf_result_npy_filename = None
        tf_result_name = None

        tflite_model_filename = None
        tflite_result_npy_filename = None
        tflite_result_name = None

        placeholder_names = []
        placeholder_vals = []
        placeholder_signatures = ()
        placeholder_npy_filenames = []
        placeholder_shapes = []

        for idx, (name, val) in enumerate(placeholders):
            placeholder_names.append(name)
            placeholder_signatures = placeholder_signatures + (
                tf.TensorSpec(shape=val.shape, dtype=val.dtype, name=name),
            )
            placeholder_npy_filenames.append("{}.npy".format(name.split(":")[0]))
            placeholder_shapes.append(val.shape)

        # Get test builder class
        fcn_node = op_fcn(*const_nodes, *addl_args, result_name)
        concrete_function = tf.function(input_signature=placeholder_signatures)(
            fcn_node.eval
        ).get_concrete_function()

        if is_quantized:

            assert dtype is tf.float32, "quantized test must come from float32 graph"

            # 1. Quantize the float placeholder npy values to feed the graph
            for idx, (name, val) in enumerate(placeholders):

                # we use np.amin()/np.amax() to determine dynamic range
                # for quantized test
                zeropoint = 0
                scale = 1.0
                if numpy_dtype[idx] != np.int64:
                    qmin = np.iinfo(numpy_dtype[idx]).min
                    qmax = np.iinfo(numpy_dtype[idx]).max
                    num_bits = np.iinfo(numpy_dtype[idx]).bits
                # 40 bit is represented as np.int64
                else:
                    num_bits = 40
                    qmin = -(1 << num_bits)
                    qmax = (1 << num_bits) - 1

                min_val = np.amin(val)
                max_val = np.amax(val)

                # for a single-value tensor, we set scale equal to abs(value),
                # and fix zeropoint to 128
                # if val > 0, it'll be represented as 129,
                #     where val = (129 - 128) * val
                # if val < 0, it'll be represented as 127,
                #     where val = (127 - 128) * (-val)
                # if val == 0, it'll be represented as 128, with range [-128.0, 128.0]
                # and let quantized 1 represent the value
                # also adjust the effective min/max accordingly
                if max_val == min_val:
                    if max_val != 0:
                        scale = abs(max_val)
                    else:
                        scale = 1.0
                    min_val = float(qmin - qzero[idx]) * scale
                    max_val = float(qmax - qzero[idx]) * scale
                else:
                    scale = (max_val - min_val) / float(qmax - qmin)
                    zeropoint = int(round((-min_val) / scale)) + qmin
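                    # i.e. an affine mapping: real_value ~= scale * (quantized_value - zeropoint)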

                # run through tf.fakequant first to ensure the quantization error is aligned
                fakequant_val = tf.quantization.fake_quant_with_min_max_args(
                    val,
                    min=min_val,
                    max=max_val,
                    num_bits=num_bits,
                    name="gen_quant_npy",
                )

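                # Map the fake-quantized float values back into integer space using the same scale and zeropoint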
                quant_val = np.round(fakequant_val / scale).astype(np.int32) + zeropoint

                # In a few unit tests after the TF hash from May 2020, this quantized
                # value exceeds the [0, 255] range for some reason
                saved_val = np.clip(quant_val, qmin, qmax).astype(numpy_dtype[idx])

                # Save all quantized tensors as np.int32,
                # since the TOSA numpy C++ API only supports int32
                np.save(
                    os.path.join(test_dir, placeholder_npy_filenames[idx]),
                    saved_val.astype(np.int32),
                    False,
                )

                placeholder_vals.append(tf.convert_to_tensor(saved_val))

            # 2. Convert the model to quantized TFLite flatbuffer
            module = tf.Module()
            converter = tf.lite.TFLiteConverter.from_concrete_functions(
                [concrete_function], module
            )
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.experimental_new_converter = True

            # use MLIR-based post-quantizer
            converter.experimental_new_quantizer = True

            flag = (
                tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8  # noqa: E501
            )
            if tflite_inference_dtype == tf.int16:
                converter.target_spec.supported_ops = [flag]

            def input_stats():
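                # Representative dataset for post-training quantization:
                # yield num_samples sets of random float inputs matching the placeholder shapes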
                for i in range(0, args.num_samples):
                    a = [
                        TGen.getRand(shape, tf.float32, rng)
                        for shape in placeholder_shapes
                    ]
                    yield a

            converter.representative_dataset = input_stats
            converter.inference_input_type = tflite_inference_dtype
            converter.inference_output_type = tflite_inference_dtype

            tflite_model = converter.convert()

            tflite_model_filename = "model.tflite"

            # Write out converted model to disk
            with open(os.path.join(test_dir, tflite_model_filename), "wb") as f:
                f.write(tflite_model)

        else:  # is_quantized is False

            # 1. Save out the numpy arrays directly
            for idx, (name, val) in enumerate(placeholders):
                placeholder_vals.append(tf.convert_to_tensor(val))
                np.save(
                    os.path.join(test_dir, placeholder_npy_filenames[idx]), val, False
                )

            # 2.a Save out .pb if framework includes tensorflow
            if "tf" not in excluded_framework_list:
                # Write out graph as protobuf to disk
                tf_model_filename = "model.pb"
                tf.io.write_graph(
                    concrete_function.graph, test_dir, tf_model_filename, True
                )

            # 2.b Save out .tflite if framework includes tflite
            if "tflite" not in excluded_framework_list:
                # Convert the model to TFLite flatbuffer
                module = tf.Module()
                converter = tf.lite.TFLiteConverter.from_concrete_functions(
                    [concrete_function], module
                )

                converter.experimental_new_converter = True

                # Even if it's a non-quantized int32 test, this needs to be set to tf.float32
                converter.inference_input_type = tf.float32
                converter.inference_output_type = tf.float32
                tflite_model = converter.convert()

                # Write out converted model to disk
                tflite_model_filename = "model.tflite"
                with open(os.path.join(test_dir, tflite_model_filename), "wb") as f:
                    f.write(tflite_model)

        # Get TF reference result if .pb is specified
        if tf_model_filename:
            tf_result_npy_filename = "tf_result.npy"
            tf_result = concrete_function(*placeholder_vals)
            np.save(os.path.join(test_dir, tf_result_npy_filename), tf_result, False)

            tf_result_name = result_name

        # Get TFLite inference result if .tflite is specified
        if tflite_model_filename:
            tflite_result_npy_filename = "tflite_result.npy"

            ops_with_optimized_only_kernel = ["elu", "ceil", "gather"]
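            # These ops only have an optimized TFLite kernel, so they are always run on the optimized interpreter below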

            if args.tflite_kernel_mode == "optimized" or (
                op_name in ops_with_optimized_only_kernel
            ):
                interpreter = tf.lite.Interpreter(
                    model_path=os.path.join(test_dir, tflite_model_filename)
                )
            elif args.tflite_kernel_mode == "reference":
                interpreter = tf.lite.Interpreter(
                    model_path=os.path.join(test_dir, tflite_model_filename),
                    experimental_op_resolver_type=OpResolverType.BUILTIN_REF,
                )
            else:
                assert 0, "unknown tflite interpreter mode {}".format(
                    args.tflite_kernel_mode
                )
            interpreter.allocate_tensors()

            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()

            assert len(input_details) == len(
                placeholder_vals
            ), "number of placeholder mismatch"

            for idx, val in enumerate(placeholder_vals):
                interpreter.set_tensor(input_details[idx]["index"], val.numpy())

            interpreter.invoke()
            tflite_result = interpreter.get_tensor(output_details[0]["index"])

            np.save(
                os.path.join(test_dir, tflite_result_npy_filename), tflite_result, False
            )

            # The result tensor name changes after converting to a TFLite flatbuffer,
            # so overwrite it with the information from the TFLite model directly.
            # Assume a single result tensor for now
            tflite_result_name = output_details[0]["name"]

        # Write out test descriptor
        write_test_json(
            filename=os.path.join(test_dir, "test.json"),
            tf_model_filename=tf_model_filename,
            tf_result_npy_filename=tf_result_npy_filename,
            tf_result_name=tf_result_name,
            tflite_model_filename=tflite_model_filename,
            tflite_result_npy_filename=tflite_result_npy_filename,
            tflite_result_name=tflite_result_name,
            ifm_name=placeholder_names,
            ifm_file=placeholder_npy_filenames,
            ifm_shape=placeholder_shapes,
            framework_exclusions=excluded_framework_list,
            quantized=is_quantized,
        )
    except Exception as e:
        msg = "Error running task: {}".format(e)
        print(msg)
        print(
            "".join(
                traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
            )
        )
        return False
    return True


def build_const_net(
    args,
    curr_shape,
    op_name,
    dtype,
    excluded_framework_list,
    quantized_inference_dtype,
    result_name,
    seed,
    rng,
    filter,
    unit_test_args,
):

    if quantized_inference_dtype:
        quant_dtype = get_tf_dtype(quantized_inference_dtype)
        test_dir = "test_{}_{}".format(op_name, get_shape_str(curr_shape, quant_dtype))
    else:
        test_dir = "test_{}_{}".format(op_name, get_shape_str(curr_shape, dtype))
    test_dir = os.path.join(args.output_dir, test_dir)

    # If the operator has an additional function to generate arguments, call it
    # here and iterate through the argument list that it generates
    op = TF_OP_LIST[op_name]
    op_fcn, tensor_gen_fcn, arg_gen_fcn = op["build_fcn"]

    addl_args_tuple = arg_gen_fcn(op, curr_shape, rng)
    for desc, addl_args in addl_args_tuple:
        # Only filter on the full test_name, not the output directory
        _, test_name = os.path.split(test_dir + desc)
        if not filter or filter.search(test_name):
            unit_test_args.append(
                [
                    op_name,
                    args,
                    test_dir + desc,
                    curr_shape,
                    addl_args,
                    dtype,
                    excluded_framework_list,
                    quantized_inference_dtype,
                    result_name,
                    seed,
                ]
            )


# Python's built-in hash is not reproducible between runs, so create our own hash for this purpose
def op_name_hash(op_name):
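    # Simple shift-and-xor (CRC-style) hash so the value is deterministic across runs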
    result = 0xDEADBEEF
    for ch in op_name:
        if result & 1:
            result = (ord(ch) << 24) ^ (result >> 1) ^ 0x82608EDB
        else:
            result = (ord(ch) << 24) ^ (result >> 1)

    return result


def generate_op_tests(args, op_name, shape_list, result_name, filter, unit_test_args):

    if not args.quiet:
        print(
            "Generating tests for {} ".format(
                op_name
            )
        )

    op = TF_OP_LIST[op_name]

    # Seed the RNG so that we get the same random tests for each test each time
    # If the number of tests for a given generation function changes, the tests
    # for that operator may also change accordingly, but this will at least keep
    # down churn across operators.

    bounded_hash_val = (args.random_seed + op_name_hash(op_name)) % np.iinfo(
        np.int32
    ).max
    rng = np.random.default_rng(bounded_hash_val)

    # op["types"] is either a list of data types, or a dictionary with 'tf' and 'tflite'
    # as keys and the data types we want to test under each framework as values

    if isinstance(op["types"], dict):
        try:
            tf_dtypes = op["types"]["tf"]
        except KeyError:
            tf_dtypes = []
        try:
            tflite_dtypes = op["types"]["tflite"]
        except KeyError:
            tflite_dtypes = []
    elif isinstance(op["types"], list):
        tf_dtypes = op["types"]
        tflite_dtypes = op["types"]

    tf_nonquantized_dtypes = tf_dtypes  # tf doesn't support quantized data types
    tflite_quantized_dtypes = []
    tflite_nonquantized_dtypes = []
    for dtype in tflite_dtypes:
        if isinstance(dtype, QuantType):
            tflite_quantized_dtypes.append(dtype)
        else:
            tflite_nonquantized_dtypes.append(dtype)

    nonquantized_dtypes_set = set(tf_nonquantized_dtypes).union(
        set(tflite_nonquantized_dtypes)
    )
    nonquantized_dtypes = list(nonquantized_dtypes_set)
    quantized_dtypes = tflite_quantized_dtypes

    # populate non quantized unit test arguments
    for dtype in nonquantized_dtypes:

        excluded_framework_set = set(ALL_FRAMEWORKS)
        if dtype in tf_nonquantized_dtypes:
            excluded_framework_set.remove("tf")
        if dtype in tflite_nonquantized_dtypes:
            excluded_framework_set.remove("tflite")
        excluded_framework_list = list(excluded_framework_set)

        for curr_shape in shape_list:
            build_const_net(
                args,
                curr_shape,
                op_name,
                dtype,
                excluded_framework_list,
                None,
                result_name,
                bounded_hash_val,
                rng,
                filter,
                unit_test_args,
            )

    # populate quantized unit test arguments
    # these must exclude 'tf' and use tf.float32 as the source dtype
    for dtype in quantized_dtypes:
        for curr_shape in shape_list:
            build_const_net(
                args,
                curr_shape,
                op_name,
                tf.float32,
                ["tf"],
                dtype,
                result_name,
                bounded_hash_val,
                rng,
                filter,
                unit_test_args,
            )

    return unit_test_args


def createDynamicOpLists():
    """The templated operators are conv2d-style operators with a number of kernel
    sizes. Since the operator is unchanged, we generate the range of kernel
    sizes here in this loop and remove the original templates from the list.

    This could be expanded to non-conv2d-style operators in the future."""

    # Dynamically create op lists for convolutions with a list of kernel sizes
    KERNELS = [
        [1, 1],
        [3, 3],
        [5, 5],
    ]

    TEMPLATE_LIST = [
        "conv2d",
        "conv2d_bias",
        "conv2d_relu",
        "conv2d_relu6",
        "conv2d_relu_n1_to_1",
        "conv2d_tanh",
        "depthwise_conv2d",
        "depthwise_conv2d_bias",
        "transpose_conv2d",
    ]

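    # e.g. the "conv2d_TEMPLATE" entry expands into "conv2d_1x1", "conv2d_3x3" and "conv2d_5x5"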
    for t in TEMPLATE_LIST:
        for k in KERNELS:
            testName = "{}_{}x{}".format(t, k[0], k[1])
            TF_OP_LIST[testName] = TF_OP_LIST["{}_TEMPLATE".format(t)].copy()
            TF_OP_LIST[testName]["filter"] = k
            TF_OP_LIST[testName]["template"] = False

    # Delete any templates after having created any dynamic ops
    # This is a two-pass operation because it's bad practice to delete
    # keys from dictionaries while iterating
    keyList = []
    for k in TF_OP_LIST:
        try:
            if TF_OP_LIST[k]["template"]:
                keyList.append(k)
                continue
        except KeyError:
            pass

    for k in keyList:
        del TF_OP_LIST[k]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed", dest="random_seed", default=42, type=int, help="Random seed"
    )
    parser.add_argument(
        "--random-shapes",
        dest="random_shapes",
        default=0,
        type=int,
        help=(
            "Use N random shapes of each rank for generating tests, "
            "seeded with random seed"
        ),
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        dest="output_dir",
        default=".",
        type=str,
        help="Test output directory path prefix",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        dest="quiet",
        default=False,
        action="store_true",
        help="Do not print test names",
    )
    parser.add_argument(
        "-j", "--jobs", dest="jobs", type=int, default=1, help="Number of parallel jobs"
    )
    parser.add_argument(
        "-m",
        "--tflite-kernel-mode",
        dest="tflite_kernel_mode",
        type=str,
        choices=["reference", "optimized"],
        default="reference",
        help="TFLite interpreter kernel mode",
    )
    parser.add_argument(
        "--num-samples",
        dest="num_samples",
        default=200,
        type=int,
        help="Number of input samples for post-training quantization",
    )
    parser.add_argument(
        "--filter",
        dest="filter",
        default="",
        type=str,
        help="Filter test names by this expression",
    )
    args = parser.parse_args()

    # Turn the filter into a re object if present
    filter = None
    if args.filter != "":
        filter = re.compile(args.filter)
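        # e.g. --filter conv2d keeps only tests whose name contains "conv2d"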

    # Autodetect CPU count
    if args.jobs <= 0:
        args.jobs = os.cpu_count()

    # Disable TF info messages
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    try:
        os.makedirs(args.output_dir)
    except FileExistsError:
        pass

    if args.random_shapes:
        gen_rand_shapes(args)

    # Build dynamic ops
    createDynamicOpLists()

    # Generate the test list and arguments to run_unit_test()
    unit_test_args = []

    for op in TF_OP_LIST:
        generate_op_tests(args, op, shape_list, "result", filter, unit_test_args)

    errors = 0
    for t in unit_test_args:
        if not run_unit_test(*t):
            errors = errors + 1

    if not args.quiet:
        print("\nAll tasks done - with {} errors".format(errors))

    return 1 if errors else 0


if __name__ == "__main__":
    exit(main())