blob: caf63e3629c25032e1c393ec6af0a8d35927f975 [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
4from generator.tosa_utils import product
5from generator.tosa_utils import usableDTypes
6from generator.tosa_utils import valueToName
7from tosa.DType import DType
8from tosa.Op import Op
9from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010
Matthew Haddone86fd342021-09-07 16:12:21 +010011
class ErrorIf(object):
    """Symbolic names for every ERROR_IF condition exercised by the test generator.

    Each attribute is a string constant whose value equals its own name, so a
    condition can be used interchangeably as an identifier, a dictionary key
    and a human-readable test-name component.
    """

    MaxDimExceeded = "MaxDimExceeded"
    StrideSmallerEqualZero = "StrideSmallerEqualZero"
    StrideLargerEqualMax = "StrideLargerEqualMax"
    StrideLargerDimension = "StrideLargerDimension"
    OffsetSmallerEqualMin = "OffsetSmallerEqualMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    ShiftNotZero = "ShiftNotZero"
    ShiftSmallerOne = "ShiftSmallerOne"
    ShiftLargerEleven = "ShiftLargerEleven"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068
69
class TosaErrorIfArgGen:
    """Helpers that corrupt otherwise-valid operator arguments so that a
    specific ERROR_IF condition fires in the reference model.

    All methods are stateless; any randomness is drawn from the supplied
    ``testGen.rng`` generator.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        shift,
        stride,
        stride_fp,
        offset,
        offset_fp,
    ):
        """Adjust RESIZE arguments to provoke ``error_name``.

        Returns:
            Tuple (shift, stride, stride_fp, offset, offset_fp, outputDType)
            with the relevant members modified; unrelated members are passed
            through unchanged.
        """
        if outputDType == DType.FLOAT:
            # Float resize carries stride/offset in the *_fp arguments.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            # Integer resize uses fixed-point stride/offset scaled by 2**shift.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [
                        (16 >> -shift) - 1,
                        (16 >> -shift) - 1,
                    ]  # avoids other ERROR_IF checks
                    offset = [
                        (16 >> -shift) - 1,
                        (16 >> -shift) - 1,
                    ]  # avoids other ERROR_IF checks
                else:
                    stride = [
                        (16 << shift) - 1,
                        (16 << shift) - 1,
                    ]  # avoids other ERROR_IF checks
                    offset = [
                        (16 << shift) - 1,
                        (16 << shift) - 1,
                    ]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Choose any output dtype that is invalid for this mode/input
            # combination.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one member corrupted to trigger
        ``error_name``, or (None, None, None) when the error does not apply
        to this combination (the caller skips it).
        """
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when ``output_dtype`` is invalid for a RESCALE from
        ``input_dtype``."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        # NOTE(review): plain `if` (not `elif`) below is harmless here because
        # the dtype conditions are mutually exclusive; kept as-is.
        if input_dtype in [DType.INT16, DType.INT32]:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype != DType.INT8:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for the WrongInputList /
        WrongOutputList ERROR_IF checks by randomly adding or removing an
        entry. Returns the (possibly new) lists."""
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) corrupted to trigger the given SLICE ERROR_IF.

        Fixed: added the missing ``@staticmethod`` decorator (every sibling
        helper in this class is a static method), and the
        InputSizeStartLengthMismatch branch no longer appends to the caller's
        ``start``/``size`` lists in place.
        """
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                # Build new lists rather than append in place so the caller's
                # lists are not modified as a side effect.
                newStart = start + [1]
                newSize = size + [1]
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of output dtypes that are invalid for CAST from
        ``input_dtype``.

        Fixed: the unsupported-dtype branch used ``assert True, ...`` which
        can never fire, after which the function raised NameError on the
        return; it now asserts False with the intended message.
        """
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
305
306
307class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        overall_result = True
        for val_fcn in validator_fcns:
            # Run the validator in "check" mode; it returns a dict describing
            # whether its ERROR_IF condition actually fired.
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match;
            # i.e. the targeted validator must fire and no other validator may.
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                # The intended ERROR_IF fired: mark the test as expected-fail.
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                # Dump the validator arguments to help diagnose the mismatch.
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f" {k} = {v}")

        return overall_result
352
353 @staticmethod
354 def evWrongInputType(check=False, **kwargs):
355 error_result = False
356
357 # Find the unsupported input data types
358 op = kwargs["op"]
359 input_dtypes = op["types"]
360 allowed_input_dtypes = {
361 t[0] if isinstance(t, list) else t for t in input_dtypes
362 }
363 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
364
365 if op["op"] == Op.CLAMP:
366 wrong_input_dtypes.remove(DType.INT48)
367
368 if check:
369 input_dtype = kwargs["input_dtype"]
370 if input_dtype not in allowed_input_dtypes:
371 error_result = True
372
373 info_dict = {
374 "error_name": ErrorIf.WrongInputType,
375 "error_result": error_result,
376 "error_reason": "Input data type not supported for this operator",
377 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
378 }
379 return info_dict
380
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for the WrongOutputType ERROR_IF condition.

        Each operator family has its own valid input->output dtype mapping;
        error_result is set when the supplied output dtype breaks it.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # Valid combinations: NEAREST preserves the input dtype;
                # BILINEAR widens INT8->INT32 and INT16->INT48; FLOAT->FLOAT.
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # Mirrors TosaErrorIfArgGen.eiRescaleWrongOutputType.
                # NOTE(review): the INT16/INT32 test below is `if`, not
                # `elif`, so it is evaluated even after the INT8 branch;
                # harmless since the conditions are mutually exclusive.
                if input_dtype == DType.INT8:
                    if output_dtype not in [
                        DType.UINT8,
                        DType.INT8,
                        DType.INT16,
                        DType.INT32,
                    ]:
                        error_result = True
                if input_dtype in [DType.INT16, DType.INT32]:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.INT48:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.UINT8:
                    if output_dtype != DType.INT8:
                        error_result = True

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator widening: INT8->INT32, INT16->INT48, FLOAT->FLOAT.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always returns INT32 indices.
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL must produce INT32; float MUL must produce FLOAT.
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison operators always produce BOOL.
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # CAST must change the dtype and stay within supported pairs.
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                # Convolution accumulators: INT8->INT32, INT16->INT48,
                # FLOAT->FLOAT.
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default: output dtype must match the input dtype.
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
525
526 @staticmethod
527 def evWrongRank(check=False, **kwargs):
528 all_ranks = (1, 2, 3, 4, 5)
529
530 # Make a list of incorrect ranks
531 assert "op" in kwargs
532 op = kwargs["op"]
533 rmin, rmax = op["rank"]
534 rank_range = range(rmin, rmax + 1)
535 incorrect_ranks = list(set(all_ranks) - set(rank_range))
536 # Remove small incorrect ranks to avoid index errors
537 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
538 # Set minimum incorrect rank to 3 to avoid index error
539 if op["op"] in [Op.RESIZE]:
540 incorrect_ranks = [3, 5]
541 elif op["op"] in [Op.TRANSPOSE]:
542 incorrect_ranks = [7, 8]
543 elif op["op"] in [Op.CONV3D]:
544 incorrect_ranks = [6, 7]
545
546 error_name = ErrorIf.WrongRank
547 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
548 error_result = False
549 error_reason = "Rank not supported for this operator"
550
551 if check:
552 input_shape = kwargs["input_shape"]
553
554 if (
555 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
556 and len(input_shape) != 4
557 ):
558 error_result = True
559 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
560 error_result = True
561 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
562 error_result = True
563 else:
564 if len(input_shape) not in rank_range:
565 error_result = True
566
567 info_dict = {
568 "error_name": error_name,
569 "error_result": error_result,
570 "error_reason": error_reason,
571 "param_reqs": param_reqs,
572 }
573 return info_dict
574
575 @staticmethod
576 def evWrongInputList(check=False, **kwargs):
577 error_name = ErrorIf.WrongInputList
578 param_reqs = {"rank": None, "dtype": None, "shape": None}
579 error_result = False
580 error_reason = "Op input list does not match expected input"
581
582 if check:
583 op = kwargs["op"]
584 input_list = kwargs["input_list"]
585 num_operands = kwargs["num_operands"]
586 if op["op"] in [Op.SCATTER, Op.GATHER]:
587 # SCATTER/GATHER add an indices input tensor in their build functions
588 num_operands += 1
589 if len(input_list) != num_operands:
590 error_result = True
591
592 info_dict = {
593 "error_name": error_name,
594 "error_result": error_result,
595 "error_reason": error_reason,
596 "param_reqs": param_reqs,
597 }
598 return info_dict
599
600 @staticmethod
601 def evWrongOutputList(check=False, **kwargs):
602 error_name = ErrorIf.WrongOutputList
603 param_reqs = {"rank": None, "dtype": None, "shape": None}
604 error_result = False
605 error_reason = "Op output list does not match expected output"
606
607 if check:
608 output_list = kwargs["output_list"]
609 # Note this will be incorrect if an operator returns more than one output
610 if len(output_list) != 1:
611 error_result = True
612
613 info_dict = {
614 "error_name": error_name,
615 "error_result": error_result,
616 "error_reason": error_reason,
617 "param_reqs": param_reqs,
618 }
619 return info_dict
620
621 @staticmethod
622 def evMaxDimExceeded(check=False, **kwargs):
623 error_name = ErrorIf.MaxDimExceeded
624 param_reqs = {
625 "rank": [4, 4],
626 "dtype": [DType.INT8],
627 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
628 }
629 error_result = False
630 error_reason = (
631 "At least one maximum dimension is greater than or equal to 16384"
632 )
633
634 if check:
635 input_shape = kwargs["input_shape"]
636 output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
637 if (
638 (input_shape[1] >= 16384)
639 or (input_shape[2] >= 16384)
640 or (output_shape[0] >= 16384)
641 or (output_shape[1] >= 16384)
642 ):
643 error_result = True
644
645 info_dict = {
646 "error_name": error_name,
647 "error_result": error_result,
648 "error_reason": error_reason,
649 "param_reqs": param_reqs,
650 }
651 return info_dict
652
653 @staticmethod
654 def evBatchMismatch(check=False, **kwargs):
655 error_name = ErrorIf.BatchMismatch
656 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
657 error_result = False
658 error_reason = "Input batch size not equal to output batch size"
659
660 assert "op" in kwargs
661 op = kwargs["op"]
662 rmin, rmax = op["rank"]
663 rank_range = range(rmin, rmax + 1)
664
665 if check:
666 input_shape = kwargs["input_shape"]
667 output_shape = kwargs[
668 "result_tensor"
669 ].shape # Note this is just (N, OH, OW, C)
670
671 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
672 error_result = True
673
674 info_dict = {
675 "error_name": error_name,
676 "error_result": error_result,
677 "error_reason": error_reason,
678 "param_reqs": param_reqs,
679 }
680 return info_dict
681
682 @staticmethod
683 def evChannelMismatch(check=False, **kwargs):
684 error_name = ErrorIf.ChannelMismatch
685 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
686 error_result = False
687 error_reason = "Input channel size not equal to output channel size"
688
689 assert "op" in kwargs
690 op = kwargs["op"]
691 rmin, rmax = op["rank"]
692 rank_range = range(rmin, rmax + 1)
693
694 if check:
695 input_shape = kwargs["input_shape"]
696 output_shape = kwargs[
697 "result_tensor"
698 ].shape # Note this is just (N, OH, OW, C)
699 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
700 error_result = True
701
702 info_dict = {
703 "error_name": error_name,
704 "error_result": error_result,
705 "error_reason": error_reason,
706 "param_reqs": param_reqs,
707 }
708 return info_dict
709
    @staticmethod
    def evStrideSmallerEqualZero(check=False, **kwargs):
        """Validator for the StrideSmallerEqualZero ERROR_IF condition (RESIZE).

        Picks the integer or float stride argument depending on the dtypes,
        then flags any component that is zero or negative.
        """
        error_name = ErrorIf.StrideSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Stride value smaller than or equal zero"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            # The branch order matters: WrongInput/OutputType tests deliberately
            # mismatch dtypes, so the stride argument actually populated by the
            # generator must be selected, not the one the output dtype implies.
            if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
                stride = kwargs["stride"]  # Work around wrong input/output type tests
            elif output_dtype == DType.FLOAT:
                stride = kwargs["stride_fp"]
            elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                stride = kwargs[
                    "stride_fp"
                ]  # Work around wrong input/output type tests
            else:
                stride = kwargs["stride"]

            if min(stride) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
741
742 @staticmethod
743 def evStrideLargerEqualMax(check=False, **kwargs):
744 error_name = ErrorIf.StrideLargerEqualMax
745 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
746 error_result = False
747 error_reason = "Stride value larger than or equal to maximum value"
748
749 if check:
750 shift = kwargs["shift"]
751 input_dtype = kwargs["input_dtype"]
752 stride = kwargs["stride"]
753 if input_dtype in [DType.INT8, DType.INT16]:
754 if shift >= 0 and (
755 stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
756 ):
757 error_result = True
758 elif shift < 0 and (
759 stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
760 ):
761 error_result = True
762
763 info_dict = {
764 "error_name": error_name,
765 "error_result": error_result,
766 "error_reason": error_reason,
767 "param_reqs": param_reqs,
768 }
769 return info_dict
770
771 @staticmethod
772 def evStrideLargerDimension(check=False, **kwargs):
773 error_name = ErrorIf.StrideLargerDimension
774 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
775 error_result = False
776 error_reason = "Stride value larger than or equal to H/W dimension"
777
778 if check:
779 shape = kwargs["input_shape"]
780 input_dtype = kwargs["input_dtype"]
781 stride = kwargs["stride_fp"]
782
783 if (
784 input_dtype == DType.FLOAT
785 and (stride[0] > shape[1])
786 or (stride[1] > shape[2])
787 ):
788 error_result = True
789
790 info_dict = {
791 "error_name": error_name,
792 "error_result": error_result,
793 "error_reason": error_reason,
794 "param_reqs": param_reqs,
795 }
796 return info_dict
797
798 @staticmethod
799 def evOffsetSmallerEqualMin(check=False, **kwargs):
800 error_name = ErrorIf.OffsetSmallerEqualMin
801 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
802 error_result = False
803 error_reason = "Offset value smaller than or equal to minimum value"
804
805 if check:
806 shift = kwargs["shift"]
807 output_dtype = kwargs["output_dtype"]
808 if output_dtype == DType.FLOAT:
809 offset = kwargs["offset_fp"]
810 else:
811 offset = kwargs["offset"]
812
813 if shift >= 0 and (
814 offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
815 ):
816 error_result = True
817 elif shift < 0 and (
818 offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
819 ):
820 error_result = True
821
822 info_dict = {
823 "error_name": error_name,
824 "error_result": error_result,
825 "error_reason": error_reason,
826 "param_reqs": param_reqs,
827 }
828 return info_dict
829
830 @staticmethod
831 def evOffsetLargerEqualMax(check=False, **kwargs):
832 error_name = ErrorIf.OffsetLargerEqualMax
833 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
834 error_result = False
835 error_reason = "Offset value larger than or equal to maximum value"
836
837 if check:
838 shift = kwargs["shift"]
839 output_dtype = kwargs["output_dtype"]
840 if output_dtype == DType.FLOAT:
841 offset = kwargs["offset_fp"]
842 else:
843 offset = kwargs["offset"]
844
845 if shift >= 0:
846 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
847 error_result = True
848
849 if shift >= 0 and (
850 offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
851 ):
852 error_result = True
853 elif shift < 0 and (
854 offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
855 ):
856 error_result = True
857
858 info_dict = {
859 "error_name": error_name,
860 "error_result": error_result,
861 "error_reason": error_reason,
862 "param_reqs": param_reqs,
863 }
864 return info_dict
865
866 @staticmethod
867 def evShiftNotZero(check=False, **kwargs):
868 error_name = ErrorIf.ShiftNotZero
869 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
870 error_result = False
871 error_reason = "Shift value must be zero for float input"
872
873 if check:
874 shift = kwargs["shift"]
875 input_dtype = kwargs["input_dtype"]
876 output_dtype = kwargs["output_dtype"]
877 if (
878 input_dtype == DType.FLOAT
879 and output_dtype == DType.FLOAT
880 and shift != 0
881 ):
882 error_result = True
883
884 info_dict = {
885 "error_name": error_name,
886 "error_result": error_result,
887 "error_reason": error_reason,
888 "param_reqs": param_reqs,
889 }
890 return info_dict
891
892 @staticmethod
893 def evShiftSmallerOne(check=False, **kwargs):
894 error_name = ErrorIf.ShiftSmallerOne
895 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
896 error_result = False
897 error_reason = "Shift value smaller than one"
898
899 if check:
900 shift = kwargs["shift"]
901 input_dtype = kwargs["input_dtype"]
902 output_dtype = kwargs["output_dtype"]
903 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
904 error_result = True
905
906 info_dict = {
907 "error_name": error_name,
908 "error_result": error_result,
909 "error_reason": error_reason,
910 "param_reqs": param_reqs,
911 }
912 return info_dict
913
914 @staticmethod
915 def evShiftLargerEleven(check=False, **kwargs):
916 error_name = ErrorIf.ShiftLargerEleven
917 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
918 error_result = False
919 error_reason = "Shift value larger than eleven"
920
921 if check:
922 shift = kwargs["shift"]
923 if shift > 11:
924 error_result = True
925
926 info_dict = {
927 "error_name": error_name,
928 "error_result": error_result,
929 "error_reason": error_reason,
930 "param_reqs": param_reqs,
931 }
932 return info_dict
933
934 @staticmethod
935 def evRankMismatch(check=False, **kwargs):
936 error_name = ErrorIf.RankMismatch
937 param_reqs = {"rank": None, "dtype": None, "shape": None}
938 error_result = False
939 error_reason = "Input Rank does not match output rank"
940
941 if check:
942 input1_shape = kwargs["input1"].shape
943 input2_shape = kwargs["input2"].shape
944 # In case of SELECT op
945 input3_shape = (
946 kwargs["input3"].shape if "input3" in kwargs else input2_shape
947 )
948 output_shape = kwargs["result_tensor"].shape
949 if (
950 (len(input1_shape) != len(output_shape))
951 or (len(input2_shape) != len(output_shape))
952 or (len(input3_shape) != len(output_shape))
953 ):
954 error_result = True
955
956 info_dict = {
957 "error_name": error_name,
958 "error_result": error_result,
959 "error_reason": error_reason,
960 "param_reqs": param_reqs,
961 }
962 return info_dict
963
964 @staticmethod
965 def evDimensionMismatch(check=False, **kwargs):
966 error_name = ErrorIf.DimensionMismatch
967 param_reqs = {"rank": None, "dtype": None, "shape": None}
968 error_result = False
969 error_reason = "Input Dimensions do not match output"
970
971 if check:
972 input1_shape = kwargs["input1"].shape
973 input2_shape = kwargs["input2"].shape
974 # In case of SELECT op
975 input3_shape = (
976 kwargs["input3"].shape if "input3" in kwargs else input2_shape
977 )
978 output_shape = kwargs["result_tensor"].shape
979 for i in range(
980 min(len(input1_shape), len(input2_shape), len(input3_shape))
981 ):
982 if (
983 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
984 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
985 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
986 ):
987 error_result = True
988
989 info_dict = {
990 "error_name": error_name,
991 "error_result": error_result,
992 "error_reason": error_reason,
993 "param_reqs": param_reqs,
994 }
995 return info_dict
996
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """Validator for the InputZeroPointNotZero ERROR_IF condition.

        Flags a non-zero input zero point when the input dtype is not a
        quantizable (INT8/UINT8) type. MATMUL is special-cased because it has
        two inputs, each with its own zero point.
        """
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8)

        # This does not apply to quantizable types, so request only the other
        # dtypes when generating tests (list entries describe per-tensor dtypes
        # with the input dtype first).
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            # qinfo arrives either as a plain (input_zp, output_zp) tuple or as
            # a serialized attribute object.
            if isinstance(kwargs["qinfo"], tuple):
                qinfo = kwargs["qinfo"]
                input_zero_point = qinfo[0]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs["qinfo"].ints
                input_zero_point = qinfo[0][1]

            if op["op"] == Op.MATMUL:
                qinfo = kwargs["qinfo"].ints
                # Check both MATMUL inputs; either non-quantizable input with a
                # non-zero zero point is an error.
                for dtype, zp in (
                    (kwargs["input_dtype"], qinfo[0][1]),
                    (kwargs["input2_dtype"], qinfo[1][1]),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict
1042
1043 @staticmethod
1044 def evWeightZeroPointNotZero(check=False, **kwargs):
1045 op = kwargs["op"]
1046
1047 # exclude inputs with INT8 weights
1048 inputDtypes = [
1049 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1050 ]
1051
1052 error_name = ErrorIf.WeightZeroPointNotZero
1053 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1054 error_result = False
1055 error_reason = "Weight DType not INT8 and zero point not 0"
1056
1057 if check:
1058 weight_dtype = kwargs["weight_dtype"]
1059 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
1060 qinfo = kwargs["qinfo"].ints
1061 weight_zero_point = qinfo[1][1]
1062 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1063 error_result = True
1064
1065 info_dict = {
1066 "error_name": error_name,
1067 "error_result": error_result,
1068 "error_reason": error_reason,
1069 "param_reqs": param_reqs,
1070 }
1071 return info_dict
1072
    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        """Check for an output zero point that must be zero.

        Outputs whose dtype is not INT8/UINT8 must have a zero point of 0.
        For AVG_POOL2D the check is keyed off the input dtype instead.
        """
        op = kwargs["op"]
        # Quantizable types are removed - they may legitimately carry a zero point
        inputDtypes = op["types"].copy()
        if DType.INT8 in inputDtypes:
            inputDtypes.remove(DType.INT8)
        if DType.UINT8 in inputDtypes:
            inputDtypes.remove(DType.UINT8)

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            # qinfo arrives either as a plain (input_zp, output_zp) tuple or as
            # a serializer object exposing .ints
            if isinstance(kwargs["qinfo"], tuple):
                qinfo = kwargs["qinfo"]
                output_zero_point = qinfo[1]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs["qinfo"].ints
                output_zero_point = qinfo[1][1]
            if op["op"] == Op.AVG_POOL2D:
                # AVG_POOL2D output zero point requirement depends on the input dtype
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1112
1113 @staticmethod
1114 def evAxisSmallerZero(check=False, **kwargs):
1115 error_name = ErrorIf.AxisSmallerZero
1116 param_reqs = {"rank": None, "dtype": None, "shape": None}
1117 error_result = False
1118 error_reason = "Axis smaller than zero"
1119
1120 if check:
1121 axis = kwargs["axis"]
1122 if axis < 0:
1123 error_result = True
1124
1125 info_dict = {
1126 "error_name": error_name,
1127 "error_result": error_result,
1128 "error_reason": error_reason,
1129 "param_reqs": param_reqs,
1130 }
1131 return info_dict
1132
1133 @staticmethod
1134 def evAxisLargerRank(check=False, **kwargs):
1135 error_name = ErrorIf.AxisLargerRank
1136 param_reqs = {"rank": None, "dtype": None, "shape": None}
1137 error_result = False
1138 error_reason = "Axis larger than rank"
1139
1140 if check:
1141 axis = kwargs["axis"]
1142 shape = kwargs["input_shape"]
1143 if axis > len(shape):
1144 error_result = True
1145
1146 info_dict = {
1147 "error_name": error_name,
1148 "error_result": error_result,
1149 "error_reason": error_reason,
1150 "param_reqs": param_reqs,
1151 }
1152 return info_dict
1153
1154 @staticmethod
1155 def evShapeOfAxisNotOne(check=False, **kwargs):
1156 error_name = ErrorIf.ShapeOfAxisNotOne
1157 param_reqs = {"rank": None, "dtype": None, "shape": None}
1158 error_result = False
1159 error_reason = "shape[axis] is not equal to 1"
1160
1161 if check:
1162 axis = kwargs["axis"]
1163 shape = kwargs["output_shape"]
1164 if (0 <= axis < len(shape)) and shape[axis] != 1:
1165 error_result = True
1166
1167 info_dict = {
1168 "error_name": error_name,
1169 "error_result": error_result,
1170 "error_reason": error_reason,
1171 "param_reqs": param_reqs,
1172 }
1173 return info_dict
1174
1175 @staticmethod
1176 def evPadSmallerZero(check=False, **kwargs):
1177 error_name = ErrorIf.PadSmallerZero
1178 param_reqs = {"rank": None, "dtype": None, "shape": None}
1179 error_result = False
1180 error_reason = "At least one pad is smaller than zero"
1181
1182 if check:
1183 op = kwargs["op"]
1184 pad = kwargs["pad"]
1185 if op["op"] == Op.PAD:
1186 for padding in pad:
1187 if min(padding) < 0:
1188 error_result = True
1189 else:
1190 if min(pad) < 0:
1191 error_result = True
1192
1193 info_dict = {
1194 "error_name": error_name,
1195 "error_result": error_result,
1196 "error_reason": error_reason,
1197 "param_reqs": param_reqs,
1198 }
1199 return info_dict
1200
1201 @staticmethod
1202 def evPadLargerEqualKernel(check=False, **kwargs):
1203 error_name = ErrorIf.PadLargerEqualKernel
1204 param_reqs = {"rank": None, "dtype": None, "shape": None}
1205 error_result = False
1206 error_reason = "At least one pad is larger than kernel dimension"
1207
1208 if check:
1209 pad = kwargs["pad"]
1210 kernel = kwargs["kernel"]
1211 if min(pad) > 0 and min(kernel) > 1:
1212 if (
1213 pad[0] >= kernel[0]
1214 or pad[1] >= kernel[0]
1215 or pad[2] >= kernel[1]
1216 or pad[3] >= kernel[1]
1217 ):
1218 error_result = True
1219
1220 info_dict = {
1221 "error_name": error_name,
1222 "error_result": error_result,
1223 "error_reason": error_reason,
1224 "param_reqs": param_reqs,
1225 }
1226 return info_dict
1227
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """Check that a pooling op's declared output H/W match the computed ones.

        Recomputes the expected output height/width from input shape, pad,
        kernel and stride, and flags an error when the provided output shape
        disagrees.  Skipped when the pooling parameters themselves are invalid
        so the dedicated ERROR_IF checks report those cases instead.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            # NHWC layout: dims 1 and 2 are height and width
            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            # (guarded so zero strides cannot divide; params_valid below
            # guarantees these were computed before they are used)
            if stride_x != 0 and stride_y != 0:
                y_correct = (
                    IH + pad_top + pad_bottom + stride_y - kernel_y
                ) // stride_y
                x_correct = (
                    IW + pad_left + pad_right + stride_x - kernel_x
                ) // stride_x

            # ensure parameters are valid
            params_valid = (
                min(kernel) >= 1
                and min(stride) >= 1
                and min(pad) >= 0
                and not (
                    pad[0] >= kernel[0]
                    or pad[1] >= kernel[0]
                    or pad[2] >= kernel[1]
                    or pad[3] >= kernel[1]
                )
            )

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1285
1286 @staticmethod
1287 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1288 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1289 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1290 error_result = False
1291 error_reason = (
1292 "Mismatch between output shape provided and expected output shape"
1293 )
1294
1295 if check:
1296 output_shape = kwargs["output_shape"]
1297 input_shape = kwargs["input_shape"]
1298 axis = kwargs["axis"]
1299
1300 dimension_match = True
1301 axis_shift = 0
1302
1303 # Check that rank is correct before trying to check dimensions
1304 if (len(input_shape) - 1) == len(output_shape):
1305 for i in range(len(input_shape)):
1306 if i == axis:
1307 axis_shift = 1
1308 continue
1309 if input_shape[i] != output_shape[i - axis_shift]:
1310 dimension_match = False
1311
1312 if not dimension_match:
1313 error_result = True
1314
1315 info_dict = {
1316 "error_name": error_name,
1317 "error_result": error_result,
1318 "error_reason": error_reason,
1319 "param_reqs": param_reqs,
1320 }
1321 return info_dict
1322
1323 @staticmethod
1324 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1325 error_name = ErrorIf.ArgmaxOutputRankMismatch
1326 param_reqs = {"rank": None, "dtype": None, "shape": None}
1327 error_result = False
1328 error_reason = (
1329 "Mismatch between output shape provided and expected output shape"
1330 )
1331
1332 if check:
1333 output_shape = kwargs["output_shape"]
1334 input_shape = kwargs["input_shape"]
1335 axis = kwargs["axis"]
1336 valid_params = axis >= 0 and axis < len(input_shape)
1337
1338 if valid_params and (len(input_shape) - 1) != len(output_shape):
1339 error_result = True
1340
1341 info_dict = {
1342 "error_name": error_name,
1343 "error_result": error_result,
1344 "error_reason": error_reason,
1345 "param_reqs": param_reqs,
1346 }
1347 return info_dict
1348
1349 @staticmethod
1350 def evKernelSmallerOne(check=False, **kwargs):
1351 error_name = ErrorIf.KernelSmallerOne
1352 param_reqs = {"rank": None, "dtype": None, "shape": None}
1353 error_result = False
1354 error_reason = "At least one kernel dimension is smaller than zero"
1355
1356 if check:
1357 kernel = kwargs["kernel"]
1358 if min(kernel) < 1:
1359 error_result = True
1360
1361 info_dict = {
1362 "error_name": error_name,
1363 "error_result": error_result,
1364 "error_reason": error_reason,
1365 "param_reqs": param_reqs,
1366 }
1367 return info_dict
1368
1369 @staticmethod
1370 def evStrideSmallerOne(check=False, **kwargs):
1371 error_name = ErrorIf.StrideSmallerOne
1372 param_reqs = {"rank": None, "dtype": None, "shape": None}
1373 error_result = False
1374 error_reason = "At least one stride dimension is smaller than zero"
1375
1376 if check:
1377 stride = kwargs["stride"]
1378 if min(stride) < 1:
1379 error_result = True
1380
1381 info_dict = {
1382 "error_name": error_name,
1383 "error_result": error_result,
1384 "error_reason": error_reason,
1385 "param_reqs": param_reqs,
1386 }
1387 return info_dict
1388
1389 @staticmethod
1390 def evDilationSmallerOne(check=False, **kwargs):
1391 error_result = check and min(kwargs["dilation"]) < 1
1392 return {
1393 "error_name": ErrorIf.DilationSmallerOne,
1394 "error_reason": "At least one dilation is smaller than one",
1395 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1396 "error_result": error_result,
1397 }
1398
1399 @staticmethod
1400 def evScaleTrue(check=False, **kwargs):
1401 error_name = ErrorIf.ScaleTrue
1402 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1403 error_result = False
1404 error_reason = "Scale set to true but input type is INT48"
1405
1406 if check:
1407 input_dtype = kwargs["input_dtype"]
1408 scale32 = kwargs["scale32"]
1409 if scale32 and input_dtype == DType.INT48:
1410 error_result = True
1411
1412 info_dict = {
1413 "error_name": error_name,
1414 "error_result": error_result,
1415 "error_reason": error_reason,
1416 "param_reqs": param_reqs,
1417 }
1418 return info_dict
1419
1420 @staticmethod
1421 def evScaleNotTrue(check=False, **kwargs):
1422 error_name = ErrorIf.ScaleNotTrue
1423 param_reqs = {"rank": None, "dtype": None, "shape": None}
1424 error_result = False
1425 error_reason = "Scale set to false but double round set to true"
1426
1427 if check:
1428 scale32 = kwargs["scale32"]
1429 double_round = kwargs["double_round"]
1430 if not scale32 and double_round:
1431 error_result = True
1432
1433 info_dict = {
1434 "error_name": error_name,
1435 "error_result": error_result,
1436 "error_reason": error_reason,
1437 "param_reqs": param_reqs,
1438 }
1439 return info_dict
1440
1441 @staticmethod
1442 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1443 error_name = ErrorIf.TensorSizeInputOutputMismatch
1444 param_reqs = {"rank": None, "dtype": None, "shape": None}
1445 error_result = False
1446 error_reason = "Input tensor size does not match output tensor size"
1447
1448 if check:
1449 input_shape = kwargs["input_shape"]
1450 output_shape = kwargs["output_shape"]
1451 input_size = np.prod(input_shape)
1452 output_size = np.prod(output_shape)
1453 if input_size != output_size:
1454 error_result = True
1455
1456 info_dict = {
1457 "error_name": error_name,
1458 "error_result": error_result,
1459 "error_reason": error_reason,
1460 "param_reqs": param_reqs,
1461 }
1462 return info_dict
1463
1464 @staticmethod
1465 def evStartSmallerZero(check=False, **kwargs):
1466 error_name = ErrorIf.StartSmallerZero
1467 param_reqs = {"rank": None, "dtype": None, "shape": None}
1468 error_result = False
1469 error_reason = "Starting point smaller than zero"
1470
1471 if check:
1472 input_shape = kwargs["input_shape"]
1473 start = kwargs["start"]
1474 rank = len(input_shape)
1475 if len(start) == rank:
1476 for index in range(rank):
1477 if start[index] < 0:
1478 error_result = True
1479
1480 info_dict = {
1481 "error_name": error_name,
1482 "error_result": error_result,
1483 "error_reason": error_reason,
1484 "param_reqs": param_reqs,
1485 }
1486 return info_dict
1487
1488 @staticmethod
1489 def evSizeSmallerEqualZero(check=False, **kwargs):
1490 error_name = ErrorIf.SizeSmallerEqualZero
1491 param_reqs = {"rank": None, "dtype": None, "shape": None}
1492 error_result = False
1493 error_reason = "Size smaller than or equal to zero"
1494
1495 if check:
1496 input_shape = kwargs["input_shape"]
1497 size = kwargs["size"]
1498 rank = len(input_shape)
1499 if len(size) == rank:
1500 for index in range(rank):
1501 if size[index] <= 0:
1502 error_result = True
1503
1504 info_dict = {
1505 "error_name": error_name,
1506 "error_result": error_result,
1507 "error_reason": error_reason,
1508 "param_reqs": param_reqs,
1509 }
1510 return info_dict
1511
1512 @staticmethod
1513 def evStartSizeOutsideBounds(check=False, **kwargs):
1514 error_name = ErrorIf.StartSizeOutsideBounds
1515 param_reqs = {"rank": None, "dtype": None, "shape": None}
1516 error_result = False
1517 error_reason = "starting point plus size larger than input dimension"
1518
1519 if check:
1520 input_shape = kwargs["input_shape"]
1521 start = kwargs["start"]
1522 size = kwargs["size"]
1523 rank = len(input_shape)
1524 if len(start) == rank and len(size) == rank:
1525 for index in range(rank):
1526 if start[index] + size[index] > input_shape[index]:
1527 error_result = True
1528
1529 info_dict = {
1530 "error_name": error_name,
1531 "error_result": error_result,
1532 "error_reason": error_reason,
1533 "param_reqs": param_reqs,
1534 }
1535 return info_dict
1536
1537 @staticmethod
1538 def evSizeOutputShapeMismatch(check=False, **kwargs):
1539 error_name = ErrorIf.SizeOutputShapeMismatch
1540 param_reqs = {"rank": None, "dtype": None, "shape": None}
1541 error_result = False
1542 error_reason = "Size does not match output dimension"
1543
1544 if check:
1545 input_shape = kwargs["input_shape"]
1546 output_shape = kwargs["output_shape"]
1547 size = kwargs["size"]
1548 rank = len(input_shape)
1549 if len(size) == rank:
1550 for index in range(rank):
1551 if size[index] != output_shape[index]:
1552 error_result = True
1553
1554 info_dict = {
1555 "error_name": error_name,
1556 "error_result": error_result,
1557 "error_reason": error_reason,
1558 "param_reqs": param_reqs,
1559 }
1560 return info_dict
1561
1562 @staticmethod
1563 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1564 error_name = ErrorIf.InputSizeStartLengthMismatch
1565 param_reqs = {"rank": None, "dtype": None, "shape": None}
1566 error_result = False
1567 error_reason = "rank of input not equal to length of start or size"
1568
1569 if check:
1570 input_shape = kwargs["input_shape"]
1571 start = kwargs["start"]
1572 size = kwargs["size"]
1573 rank = len(input_shape)
1574 if rank != len(start) or rank != len(size):
1575 error_result = True
1576
1577 info_dict = {
1578 "error_name": error_name,
1579 "error_result": error_result,
1580 "error_reason": error_reason,
1581 "param_reqs": param_reqs,
1582 }
1583 return info_dict
1584
1585 @staticmethod
1586 def evIndexOutsideBounds(check=False, **kwargs):
1587 error_name = ErrorIf.IndexOutsideBounds
1588 param_reqs = {"rank": None, "dtype": None, "shape": None}
1589 error_result = False
1590 error_reason = "Index outside of allowed bounds"
1591
1592 if check:
1593 input_shape = kwargs["input_shape"]
1594 perms = kwargs["perms"]
1595 rank = len(input_shape)
1596
1597 for index in perms:
1598 if index < 0 or index > rank:
1599 error_result = True
1600
1601 info_dict = {
1602 "error_name": error_name,
1603 "error_result": error_result,
1604 "error_reason": error_reason,
1605 "param_reqs": param_reqs,
1606 }
1607 return info_dict
1608
1609 @staticmethod
1610 def evIndexUsedTwice(check=False, **kwargs):
1611 error_name = ErrorIf.IndexUsedTwice
1612 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1613 error_result = False
1614 error_reason = "Index used multiple times"
1615
1616 if check:
1617 perms = kwargs["perms"]
1618
1619 unique_indices = []
1620 for index in perms:
1621 if index in unique_indices:
1622 error_result = True
1623 else:
1624 unique_indices.append(index)
1625
1626 info_dict = {
1627 "error_name": error_name,
1628 "error_result": error_result,
1629 "error_reason": error_reason,
1630 "param_reqs": param_reqs,
1631 }
1632 return info_dict
1633
1634 @staticmethod
1635 def evMaxSmallerMin(check=False, **kwargs):
1636 error_name = ErrorIf.MaxSmallerMin
1637 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1638 error_result = False
1639 error_reason = "Max value smaller than min value"
1640
1641 if check:
1642 max_val = kwargs["max_val"]
1643 min_val = kwargs["min_val"]
1644 if max_val < min_val:
1645 error_result = True
1646
1647 info_dict = {
1648 "error_name": error_name,
1649 "error_result": error_result,
1650 "error_reason": error_reason,
1651 "param_reqs": param_reqs,
1652 }
1653 return info_dict
1654
1655 @staticmethod
1656 def evConcatInputRankMismatch(check=False, **kwargs):
1657 error_name = ErrorIf.ConcatInputRankMismatch
1658 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1659 error_result = False
1660 error_reason = "Input ranks are not identical"
1661
1662 if check:
1663 inputs = kwargs["inputs"]
1664 input_shape = kwargs["input_shape"]
1665 for input in inputs:
1666 if len(input.shape) != len(input_shape):
1667 error_result = True
1668
1669 info_dict = {
1670 "error_name": error_name,
1671 "error_result": error_result,
1672 "error_reason": error_reason,
1673 "param_reqs": param_reqs,
1674 }
1675 return info_dict
1676
1677 @staticmethod
1678 def evConcatInputDimMismatch(check=False, **kwargs):
1679 error_name = ErrorIf.ConcatInputDimMismatch
1680 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1681 error_result = False
1682 error_reason = "Input dimensions differ on too many axes"
1683
1684 if check:
1685 inputs = kwargs["inputs"]
1686 input_shape = kwargs["input_shape"]
1687 axis = kwargs["axis"]
1688
1689 # Ensure rank is valid before checking dims.
1690 valid_rank = True
1691 for input in inputs:
1692 if len(input.shape) != len(input_shape):
1693 valid_rank = False
1694
1695 if valid_rank:
1696 for input in inputs:
1697 for i, dim in enumerate(input.shape):
1698 if dim != input_shape[i] and axis != i:
1699 error_result = True
1700
1701 info_dict = {
1702 "error_name": error_name,
1703 "error_result": error_result,
1704 "error_reason": error_reason,
1705 "param_reqs": param_reqs,
1706 }
1707 return info_dict
1708
1709 @staticmethod
1710 def evConcatShapeSumMismatch(check=False, **kwargs):
1711 error_name = ErrorIf.ConcatShapeSumMismatch
1712 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1713 error_result = False
1714 error_reason = "Sum of dimensions on axis not equal to output dimension"
1715
1716 if check:
1717 inputs = kwargs["inputs"]
1718 input_shape = kwargs["input_shape"]
1719 output_shape = kwargs["output_shape"]
1720 axis = kwargs["axis"]
1721
1722 # Ensure rank is valid before checking dims.
1723 valid_params = True
1724 for input in inputs:
1725 if len(input.shape) != len(input_shape):
1726 valid_params = False
1727 if axis < 0 or axis > len(input_shape):
1728 valid_params = False
1729
1730 if valid_params:
1731 axis_dim_sum = 0
1732 for input in inputs:
1733 axis_dim_sum += input.shape[axis]
1734
1735 if axis_dim_sum != output_shape[axis]:
1736 error_result = True
1737
1738 info_dict = {
1739 "error_name": error_name,
1740 "error_result": error_result,
1741 "error_reason": error_reason,
1742 "param_reqs": param_reqs,
1743 }
1744 return info_dict
1745
1746 @staticmethod
1747 def evInputListThenGraphMismatch(check=False, **kwargs):
1748 error_name = ErrorIf.CondIfInputListThenGraphMismatch
1749 param_reqs = {"rank": None, "dtype": None, "shape": None}
1750 error_result = False
1751 error_reason = "Input list shape does not match then-graph shape"
1752
1753 if check:
1754 a = kwargs["a"]
1755 b = kwargs["b"]
1756 basicBlocks = kwargs["basicBlocks"]
1757 then_block = basicBlocks[1]
1758 then_inputs = then_block.inputs
1759 then_tens = then_block.tensors
1760 if (a.shape != then_tens[then_inputs[0]].shape) or (
1761 b.shape != then_tens[then_inputs[1]].shape
1762 ):
1763 error_result = True
1764
1765 info_dict = {
1766 "error_name": error_name,
1767 "error_result": error_result,
1768 "error_reason": error_reason,
1769 "param_reqs": param_reqs,
1770 }
1771 return info_dict
1772
1773 @staticmethod
1774 def evInputListElseGraphMismatch(check=False, **kwargs):
1775 error_name = ErrorIf.CondIfInputListElseGraphMismatch
1776 param_reqs = {"rank": None, "dtype": None, "shape": None}
1777 error_result = False
1778 error_reason = "Input list shape does not match else-graph shape"
1779
1780 if check:
1781 a = kwargs["a"]
1782 b = kwargs["b"]
1783 basicBlocks = kwargs["basicBlocks"]
1784 else_block = basicBlocks[2]
1785 else_inputs = else_block.inputs
1786 else_tens = else_block.tensors
1787 if (a.shape != else_tens[else_inputs[0]].shape) or (
1788 b.shape != else_tens[else_inputs[1]].shape
1789 ):
1790 error_result = True
1791
1792 info_dict = {
1793 "error_name": error_name,
1794 "error_result": error_result,
1795 "error_reason": error_reason,
1796 "param_reqs": param_reqs,
1797 }
1798 return info_dict
1799
1800 @staticmethod
1801 def evOutputListThenGraphMismatch(check=False, **kwargs):
1802 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
1803 param_reqs = {"rank": None, "dtype": None, "shape": None}
1804 error_result = False
1805 error_reason = "Output list shape does not match then-graph shape"
1806
1807 if check:
1808 basicBlocks = kwargs["basicBlocks"]
1809 cond_block = basicBlocks[0]
1810 cond_outputs = cond_block.outputs
1811 cond_tens = cond_block.tensors
1812 then_block = basicBlocks[1]
1813 then_outputs = then_block.outputs
1814 then_tens = then_block.tensors
1815 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
1816 error_result = True
1817
1818 info_dict = {
1819 "error_name": error_name,
1820 "error_result": error_result,
1821 "error_reason": error_reason,
1822 "param_reqs": param_reqs,
1823 }
1824 return info_dict
1825
1826 @staticmethod
1827 def evOutputListElseGraphMismatch(check=False, **kwargs):
1828 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
1829 param_reqs = {"rank": None, "dtype": None, "shape": None}
1830 error_result = False
1831 error_reason = "Output list shape does not match else-graph shape"
1832
1833 if check:
1834 basicBlocks = kwargs["basicBlocks"]
1835 cond_block = basicBlocks[0]
1836 cond_outputs = cond_block.outputs
1837 cond_tens = cond_block.tensors
1838 else_block = basicBlocks[2]
1839 else_outputs = else_block.outputs
1840 else_tens = else_block.tensors
1841 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
1842 error_result = True
1843
1844 info_dict = {
1845 "error_name": error_name,
1846 "error_result": error_result,
1847 "error_reason": error_reason,
1848 "param_reqs": param_reqs,
1849 }
1850 return info_dict
1851
1852 @staticmethod
1853 def evInputListOutputListMismatch(check=False, **kwargs):
1854 error_name = ErrorIf.InputListOutputListMismatch
1855 param_reqs = {"rank": None, "dtype": None, "shape": None}
1856 error_result = False
1857 error_reason = "Input list does not match output list"
1858
1859 if check:
1860 basicBlocks = kwargs["basicBlocks"]
1861 while_block = basicBlocks[0]
1862 while_inputs = while_block.inputs
1863 while_outputs = while_block.outputs
1864 while_tens = while_block.tensors
1865 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
1866 error_result = True
1867
1868 info_dict = {
1869 "error_name": error_name,
1870 "error_result": error_result,
1871 "error_reason": error_reason,
1872 "param_reqs": param_reqs,
1873 }
1874 return info_dict
1875
1876 @staticmethod
1877 def evInputListCondGraphMismatch(check=False, **kwargs):
1878 error_name = ErrorIf.InputListCondGraphMismatch
1879 param_reqs = {"rank": None, "dtype": None, "shape": None}
1880 error_result = False
1881 error_reason = "Input list does not match cond graph"
1882
1883 if check:
1884 basicBlocks = kwargs["basicBlocks"]
1885 while_block = basicBlocks[0]
1886 while_inputs = while_block.inputs
1887 while_tens = while_block.tensors
1888 cond_block = basicBlocks[1]
1889 cond_inputs = cond_block.inputs
1890 cond_tens = cond_block.tensors
1891 if (
1892 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
1893 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
1894 error_result = True
1895
1896 info_dict = {
1897 "error_name": error_name,
1898 "error_result": error_result,
1899 "error_reason": error_reason,
1900 "param_reqs": param_reqs,
1901 }
1902 return info_dict
1903
1904 @staticmethod
1905 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
1906 error_name = ErrorIf.InputListBodyGraphInputMismatch
1907 param_reqs = {"rank": None, "dtype": None, "shape": None}
1908 error_result = False
1909 error_reason = "Input list does not match body graph input"
1910
1911 if check:
1912 basicBlocks = kwargs["basicBlocks"]
1913 while_block = basicBlocks[0]
1914 while_inputs = while_block.inputs
1915 while_tens = while_block.tensors
1916 body_block = basicBlocks[2]
1917 body_outputs = body_block.inputs
1918 body_tens = body_block.tensors
1919 if (
1920 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
1921 ) or (
1922 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
1923 ):
1924 error_result = True
1925
1926 info_dict = {
1927 "error_name": error_name,
1928 "error_result": error_result,
1929 "error_reason": error_reason,
1930 "param_reqs": param_reqs,
1931 }
1932 return info_dict
1933
1934 @staticmethod
1935 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
1936 error_name = ErrorIf.InputListBodyGraphOutputMismatch
1937 param_reqs = {"rank": None, "dtype": None, "shape": None}
1938 error_result = False
1939 error_reason = "Input list does not match body graph output"
1940
1941 if check:
1942 basicBlocks = kwargs["basicBlocks"]
1943 while_block = basicBlocks[0]
1944 while_inputs = while_block.inputs
1945 while_tens = while_block.tensors
1946 body_block = basicBlocks[2]
1947 body_outputs = body_block.outputs
1948 body_tens = body_block.tensors
1949 if (
1950 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
1951 ) or (
1952 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
1953 ):
1954 error_result = True
1955 info_dict = {
1956 "error_name": error_name,
1957 "error_result": error_result,
1958 "error_reason": error_reason,
1959 "param_reqs": param_reqs,
1960 }
1961 return info_dict
1962
1963 @staticmethod
1964 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
1965 error_name = ErrorIf.CondGraphOutputNotMatchingBool
1966 param_reqs = {"rank": None, "dtype": None, "shape": None}
1967 error_result = False
1968 error_reason = "Cond graph output is not a match list of booleans"
1969
1970 if check:
1971 basicBlocks = kwargs["basicBlocks"]
1972 cond_block = basicBlocks[1]
1973 cond_outputs = cond_block.outputs
1974 cond_tens = cond_block.tensors
1975 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
1976 error_result = True
1977
1978 info_dict = {
1979 "error_name": error_name,
1980 "error_result": error_result,
1981 "error_reason": error_reason,
1982 "param_reqs": param_reqs,
1983 }
1984 return info_dict
1985
1986
class TosaInvalidValidator:
    """Predicates that detect parameter combinations which are invalid for an
    operator (as opposed to ERROR_IF conditions), so the generator can avoid
    creating such tests.  Each predicate returns True when the combination
    should be skipped."""

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True when the resize mode / input dtype / output dtype
        combination is invalid.

        args layout (from the argument generator): args[0] = mode,
        args[8] = output_dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # Valid only for the permitted input/output type pairings.
            # BUG FIX: this used to read `not A or not B or not C`; since the
            # three pairings are mutually exclusive that expression is True
            # for every input, so even valid BILINEAR cases were rejected.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            # NOTE(review): the permitted list contains INT32 but not INT16 -
            # confirm against the TOSA RESIZE specification.
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True when any resize stride is zero or negative.

        Integer strides live in args[1]; floating-point strides in args[4].
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when an op's computed output height/width is invalid.

        Handles pooling (output dim < 1), transpose_conv2d (declared output
        shape matches none of the FULL/SAME/VALID padding interpretations)
        and conv2d/conv3d/depthwise_conv2d (output dim < 1).
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]
        strides = args[0]
        padding = args[1]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            dilations = args[2]
            output_shape = args[3]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Based on the keras function deconv_output_length, in
                https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    dilation: the kernel dilation - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
                return (
                    (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
                )

            # Accept the declared output shape if it matches any of the three
            # conventional padding interpretations.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    dilations[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    dilations[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters are HWCM; other conv filters are OHWI/ODHWI
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True when the declared output H or W (args[3]) is <= 0."""
        args = kwargs["args"]
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False