blob: 1900d8a75ee9357a44d308438c64c2dd4b53566c [file] [log] [blame]
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
4from generator.tosa_utils import product
5from generator.tosa_utils import usableDTypes
6from generator.tosa_utils import valueToName
7from tosa.DType import DType
8from tosa.Op import Op
9from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010
Matthew Haddone86fd342021-09-07 16:12:21 +010011
12class ErrorIf(object):
13 MaxDimExceeded = "MaxDimExceeded"
14 StrideSmallerEqualZero = "StrideSmallerEqualZero"
15 StrideLargerEqualMax = "StrideLargerEqualMax"
16 StrideLargerDimension = "StrideLargerDimension"
17 OffsetSmallerEqualMin = "OffsetSmallerEqualMin"
18 OffsetLargerEqualMax = "OffsetLargerEqualMax"
19 ShiftNotZero = "ShiftNotZero"
20 ShiftSmallerOne = "ShiftSmallerOne"
21 ShiftLargerEleven = "ShiftLargerEleven"
Matthew Haddon848efb42021-09-09 12:30:53 +010022 WrongInputType = "WrongInputType"
23 WrongOutputType = "WrongOutputType"
24 WrongInputList = "WrongInputList"
25 WrongOutputList = "WrongOutputList"
26 WrongRank = "WrongRank"
Matthew Haddon693ba9e2021-09-22 11:24:37 +010027 BatchMismatch = "BatchMismatch"
28 ChannelMismatch = "ChannelMismatch"
Matthew Haddoneacff9a2021-09-24 14:42:13 +010029 RankMismatch = "RankMismatch"
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +000030 DimensionMismatch = "DimensionMismatch"
Matthew Haddone4ecdb22021-09-28 11:38:21 +010031 InputZeroPointNotZero = "InputZeroPointNotZero"
Matthew Haddonc4cf0372021-10-11 09:38:10 +010032 WeightZeroPointNotZero = "WeightZeroPointNotZero"
Matthew Haddone4ecdb22021-09-28 11:38:21 +010033 OutputZeroPointNotZero = "OutputZeroPointNotZero"
Matthew Haddond6ce7252021-09-29 15:35:44 +010034 AxisSmallerZero = "AxisSmallerZero"
35 AxisLargerRank = "AxisLargerRank"
Matthew Haddonc4cf0372021-10-11 09:38:10 +010036 ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
37 ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
Matthew Haddond6ce7252021-09-29 15:35:44 +010038 ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
Matthew Haddonb6b59e32021-10-07 17:19:20 +010039 KernelSmallerOne = "KernelSmallerOne"
40 StrideSmallerOne = "StrideSmallerOne"
Les Bell0e027d42021-11-09 14:42:14 +000041 DilationSmallerOne = "DilationSmallerOne"
Matthew Haddonb6b59e32021-10-07 17:19:20 +010042 PadSmallerZero = "PadSmallerZero"
43 PadLargerEqualKernel = "PadLargerEqualKernel"
44 PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010045 PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
46 ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
47 ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
Matthew Haddonc2025212021-10-08 21:21:05 +010048 ScaleNotTrue = "ScaleNotTrue"
49 ScaleTrue = "ScaleTrue"
Matthew Haddone807aae2021-10-11 18:12:58 +010050 TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
51 StartSmallerZero = "StartSmallerZero"
52 SizeSmallerEqualZero = "SizeSmallerEqualZero"
53 StartSizeOutsideBounds = "StartSizeOutsideBounds"
54 SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
55 InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
56 IndexOutsideBounds = "IndexOutsideBounds"
57 IndexUsedTwice = "IndexUsedTwice"
Matthew Haddonbb5676f2021-10-13 11:30:30 +010058 MaxSmallerMin = "MaxSmallerMin"
59 ConcatInputRankMismatch = "ConcatInputRankMismatch"
60 ConcatInputDimMismatch = "ConcatInputDimMismatch"
Matthew Haddon01c359d2021-10-15 16:30:48 +010061 ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
Matthew Haddon630c17c2021-10-14 15:05:41 +010062 CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
63 CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
64 CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
65 CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
66 InputListOutputListMismatch = "InputListOutputListMismatch"
67 InputListCondGraphMismatch = "InputListCondGraphMismatch"
68 InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
69 InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
70 CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010071 U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
72 U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010073
74
class TosaErrorIfArgGen:
    """Static helpers that build operator arguments for ERROR_IF tests."""

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        shift,
        stride,
        stride_fp,
        offset,
        offset_fp,
    ):
        """Alter resize arguments so that ``error_name`` is provoked.

        Args:
            testGen: object exposing an ``rng`` with ``random``, ``integers``
                and ``choice`` methods
            error_name: the ERROR_IF condition to provoke
            mode: resize mode (compared against ``ResizeMode`` values)
            dtype: input data type
            shapeList: list of shapes; only ``shapeList[0]`` is read
            outputDType: output data type
            shift, stride, stride_fp, offset, offset_fp: current arguments

        Returns:
            ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``;
            arguments not touched by ``error_name`` are returned unchanged.

        Note:
            For ``StrideLargerDimension`` the passed-in ``stride`` or
            ``stride_fp`` sequence is modified in place.
        """
        if outputDType == DType.FLOAT:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                # Stride/offset just below the limit avoids other ERROR_IF checks
                if shift <= 0:
                    below_limit = (16 >> -shift) - 1
                else:
                    below_limit = (16 << shift) - 1
                stride = [below_limit, below_limit]
                offset = [below_limit, below_limit]
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # NOTE(review): a mode/dtype pair not listed here leaves
            # incorrect_types unassigned and the choice below will fail.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Replace one pooling argument with a value that provokes ``error_name``.

        Returns:
            ``(stride, pad, kernel)`` with the relevant entry replaced, or
            ``(None, None, None)`` when ``error_name`` is not handled here
            (including ``StrideSmallerOne`` when any pad is not smaller than
            its kernel dimension).
        """
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when ``output_dtype`` is not accepted for ``input_dtype``.

        Input data types that are not listed below always yield False.
        """
        accepted_outputs = (
            (DType.INT8, [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]),
            (
                DType.INT16,
                [DType.UINT8, DType.INT8, DType.UINT16, DType.INT16, DType.INT32],
            ),
            (DType.INT32, [DType.INT8, DType.INT16, DType.INT32]),
            (DType.INT48, [DType.INT8, DType.INT16, DType.INT32]),
            (DType.UINT8, [DType.INT8, DType.INT16]),
            (DType.UINT16, [DType.INT16]),
        )
        for in_type, out_types in accepted_outputs:
            if input_dtype == in_type:
                return output_dtype not in out_types
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for ERROR_IF checks.

        For "WrongInputList" a dummy input is appended (in place) or the last
        input is dropped; for "WrongOutputList" a dummy output is appended
        (in place) or the list is emptied. The choice is made with
        ``testGen.rng.choice``. Other error names leave both lists untouched.

        Returns:
            ``(input_list, output_list)``
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.

        Dimensions are first clamped to ``max_dim``; then every dimension is
        decremented (never below 1) until the element count fits. ``shape``
        itself is returned when no adjustment is needed.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Alter slice ``start``/``size`` so that ``error_name`` is provoked.

        Returns:
            ``(start, size)``, unchanged when ``error_name`` is not handled.

        Note:
            For ``InputSizeStartLengthMismatch`` the "extend" variant appends
            to the passed-in ``start`` and ``size`` lists in place.
        """
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for _ in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for _ in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                # Start on the last element and ask for more than one
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of wrong output data types for a given cast input.

        Args:
            testGen: unused
            input_dtype: the cast input data type

        Raises:
            ValueError: if ``input_dtype`` is not one of the handled types.
        """
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # The original `assert True, ...` could never fire, so this case
            # previously surfaced as an UnboundLocalError on the return.
            raise ValueError(f"input_dtype ({input_dtype}) not supported")
        return outputDType
322
323
324class TosaErrorValidator:
325 @staticmethod
326 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
327 """Check ERROR_IF statements are caught and set the expected result.
328
329 Args:
330 serializer: the serializer to set the expected result in
331 validator_fcns: a sequence of validator functions to verify the result
332 error_name: the name of the ERROR_IF condition to check for
333 kwargs: keyword arguments for the validator functions
334 Returns:
335 True if the result matches the expected result; otherwise False
336 """
337 overall_result = True
338 for val_fcn in validator_fcns:
339 val_result = val_fcn(True, **kwargs)
340 validator_name = val_result["error_name"]
341 error_result = val_result["error_result"]
342 error_reason = val_result["error_reason"]
343
344 # expect an error IFF the error_name and validator_name match
345 expected_result = error_result == (error_name == validator_name)
346 overall_result &= expected_result
347
348 if expected_result and error_result:
349 serializer.setExpectedReturnCode(2, True, desc=error_reason)
350 elif error_result: # and not expected_result
351 print(
352 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
353 f" Expected: {error_name}, Got: {validator_name}"
354 )
355 elif not expected_result: # and not error_result
356 print(
357 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
358 f" Expected: {error_name}"
359 )
360
361 if not expected_result:
362 for k, v in sorted(kwargs.items()):
363 if k != "op":
364 if k.endswith("dtype"):
365 v = valueToName(DType, v)
366 print(f" {k} = {v}")
367
368 return overall_result
369
370 @staticmethod
371 def evWrongInputType(check=False, **kwargs):
372 error_result = False
373
374 # Find the unsupported input data types
375 op = kwargs["op"]
376 input_dtypes = op["types"]
377 allowed_input_dtypes = {
378 t[0] if isinstance(t, list) else t for t in input_dtypes
379 }
380 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
381
382 if op["op"] == Op.CLAMP:
383 wrong_input_dtypes.remove(DType.INT48)
384
385 if check:
386 input_dtype = kwargs["input_dtype"]
387 if input_dtype not in allowed_input_dtypes:
388 error_result = True
389
390 info_dict = {
391 "error_name": ErrorIf.WrongInputType,
392 "error_result": error_result,
393 "error_reason": "Input data type not supported for this operator",
394 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
395 }
396 return info_dict
397
398 @staticmethod
399 def evWrongOutputType(check=False, **kwargs):
400 error_result = False
401
402 if check:
403 input_dtype = kwargs["input_dtype"]
404 output_dtype = kwargs["output_dtype"]
405 op = kwargs["op"]
406
407 if op["op"] == Op.RESIZE:
408 mode = kwargs["mode"]
409 if (
410 (
411 mode == ResizeMode.NEAREST
412 and input_dtype == DType.INT8
413 and output_dtype != DType.INT8
414 )
415 or (
416 mode == ResizeMode.NEAREST
417 and input_dtype == DType.INT16
418 and output_dtype != DType.INT16
419 )
420 or (
421 mode == ResizeMode.BILINEAR
422 and input_dtype == DType.INT8
423 and output_dtype != DType.INT32
424 )
425 or (
426 mode == ResizeMode.BILINEAR
427 and input_dtype == DType.INT16
428 and output_dtype != DType.INT48
429 )
430 or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
431 ):
432 error_result = True
433
434 elif op["op"] == Op.RESCALE:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100435 error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
436 input_dtype, output_dtype
437 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100438
439 elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
440 if (
441 (input_dtype == DType.INT8 and output_dtype != DType.INT32)
442 or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
443 or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
444 ):
445 error_result = True
446
447 elif op["op"] == Op.ARGMAX:
448 if (
449 input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
450 and output_dtype != DType.INT32
451 ):
452 error_result = True
453
454 elif op["op"] == Op.MUL:
455 if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
456 error_result = True
457 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
458 error_result = True
459
460 elif op["op"] == Op.TABLE:
461 if input_dtype == DType.INT8 and output_dtype != DType.INT8:
462 error_result = True
463 elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
464 error_result = True
465
466 elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
467 if output_dtype != DType.BOOL:
468 error_result = True
469
470 elif op["op"] == Op.CAST:
471 if (
472 (
473 input_dtype == DType.BOOL
474 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
475 )
476 or (
477 input_dtype == DType.INT8
478 and output_dtype
479 not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
480 )
481 or (
482 input_dtype == DType.INT16
483 and output_dtype
484 not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
485 )
486 or (
487 input_dtype == DType.INT32
488 and output_dtype
489 not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
490 )
491 or (
492 input_dtype == DType.FLOAT
493 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
494 )
495 ):
496 error_result = True
497
498 elif op["op"] in {
499 Op.CONV2D,
500 Op.CONV3D,
501 Op.DEPTHWISE_CONV2D,
502 Op.TRANSPOSE_CONV2D,
503 }:
504 if (
505 input_dtype == DType.INT8
506 and output_dtype != DType.INT32
507 or input_dtype == DType.INT16
508 and output_dtype != DType.INT48
509 or input_dtype == DType.FLOAT
510 and output_dtype != DType.FLOAT
511 ):
512 error_result = True
513 # invalid input types are ignored, to avoid reporting multiple errors
514
515 else:
516 if output_dtype != input_dtype:
517 error_result = True
518
519 info_dict = {
520 "error_name": ErrorIf.WrongOutputType,
521 "error_result": error_result,
522 "error_reason": (
523 "Output data type not supported for this configuration of operator"
524 ),
525 "param_reqs": {"rank": None, "dtype": None, "shape": None},
526 }
527 return info_dict
528
529 @staticmethod
530 def evWrongRank(check=False, **kwargs):
531 all_ranks = (1, 2, 3, 4, 5)
532
533 # Make a list of incorrect ranks
534 assert "op" in kwargs
535 op = kwargs["op"]
536 rmin, rmax = op["rank"]
537 rank_range = range(rmin, rmax + 1)
538 incorrect_ranks = list(set(all_ranks) - set(rank_range))
539 # Remove small incorrect ranks to avoid index errors
540 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
541 # Set minimum incorrect rank to 3 to avoid index error
542 if op["op"] in [Op.RESIZE]:
543 incorrect_ranks = [3, 5]
544 elif op["op"] in [Op.TRANSPOSE]:
545 incorrect_ranks = [7, 8]
546 elif op["op"] in [Op.CONV3D]:
547 incorrect_ranks = [6, 7]
548
549 error_name = ErrorIf.WrongRank
550 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
551 error_result = False
552 error_reason = "Rank not supported for this operator"
553
554 if check:
555 input_shape = kwargs["input_shape"]
556
557 if (
558 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
559 and len(input_shape) != 4
560 ):
561 error_result = True
562 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
563 error_result = True
564 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
565 error_result = True
566 else:
567 if len(input_shape) not in rank_range:
568 error_result = True
569
570 info_dict = {
571 "error_name": error_name,
572 "error_result": error_result,
573 "error_reason": error_reason,
574 "param_reqs": param_reqs,
575 }
576 return info_dict
577
578 @staticmethod
579 def evWrongInputList(check=False, **kwargs):
580 error_name = ErrorIf.WrongInputList
581 param_reqs = {"rank": None, "dtype": None, "shape": None}
582 error_result = False
583 error_reason = "Op input list does not match expected input"
584
585 if check:
586 op = kwargs["op"]
587 input_list = kwargs["input_list"]
588 num_operands = kwargs["num_operands"]
589 if op["op"] in [Op.SCATTER, Op.GATHER]:
590 # SCATTER/GATHER add an indices input tensor in their build functions
591 num_operands += 1
592 if len(input_list) != num_operands:
593 error_result = True
594
595 info_dict = {
596 "error_name": error_name,
597 "error_result": error_result,
598 "error_reason": error_reason,
599 "param_reqs": param_reqs,
600 }
601 return info_dict
602
603 @staticmethod
604 def evWrongOutputList(check=False, **kwargs):
605 error_name = ErrorIf.WrongOutputList
606 param_reqs = {"rank": None, "dtype": None, "shape": None}
607 error_result = False
608 error_reason = "Op output list does not match expected output"
609
610 if check:
611 output_list = kwargs["output_list"]
612 # Note this will be incorrect if an operator returns more than one output
613 if len(output_list) != 1:
614 error_result = True
615
616 info_dict = {
617 "error_name": error_name,
618 "error_result": error_result,
619 "error_reason": error_reason,
620 "param_reqs": param_reqs,
621 }
622 return info_dict
623
624 @staticmethod
625 def evMaxDimExceeded(check=False, **kwargs):
626 error_name = ErrorIf.MaxDimExceeded
627 param_reqs = {
628 "rank": [4, 4],
629 "dtype": [DType.INT8],
630 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
631 }
632 error_result = False
633 error_reason = (
634 "At least one maximum dimension is greater than or equal to 16384"
635 )
636
637 if check:
638 input_shape = kwargs["input_shape"]
639 output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
640 if (
641 (input_shape[1] >= 16384)
642 or (input_shape[2] >= 16384)
643 or (output_shape[0] >= 16384)
644 or (output_shape[1] >= 16384)
645 ):
646 error_result = True
647
648 info_dict = {
649 "error_name": error_name,
650 "error_result": error_result,
651 "error_reason": error_reason,
652 "param_reqs": param_reqs,
653 }
654 return info_dict
655
656 @staticmethod
657 def evBatchMismatch(check=False, **kwargs):
658 error_name = ErrorIf.BatchMismatch
659 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
660 error_result = False
661 error_reason = "Input batch size not equal to output batch size"
662
663 assert "op" in kwargs
664 op = kwargs["op"]
665 rmin, rmax = op["rank"]
666 rank_range = range(rmin, rmax + 1)
667
668 if check:
669 input_shape = kwargs["input_shape"]
670 output_shape = kwargs[
671 "result_tensor"
672 ].shape # Note this is just (N, OH, OW, C)
673
674 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
675 error_result = True
676
677 info_dict = {
678 "error_name": error_name,
679 "error_result": error_result,
680 "error_reason": error_reason,
681 "param_reqs": param_reqs,
682 }
683 return info_dict
684
685 @staticmethod
686 def evChannelMismatch(check=False, **kwargs):
687 error_name = ErrorIf.ChannelMismatch
688 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
689 error_result = False
690 error_reason = "Input channel size not equal to output channel size"
691
692 assert "op" in kwargs
693 op = kwargs["op"]
694 rmin, rmax = op["rank"]
695 rank_range = range(rmin, rmax + 1)
696
697 if check:
698 input_shape = kwargs["input_shape"]
699 output_shape = kwargs[
700 "result_tensor"
701 ].shape # Note this is just (N, OH, OW, C)
702 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
703 error_result = True
704
705 info_dict = {
706 "error_name": error_name,
707 "error_result": error_result,
708 "error_reason": error_reason,
709 "param_reqs": param_reqs,
710 }
711 return info_dict
712
713 @staticmethod
714 def evStrideSmallerEqualZero(check=False, **kwargs):
715 error_name = ErrorIf.StrideSmallerEqualZero
716 param_reqs = {"rank": None, "dtype": None, "shape": None}
717 error_result = False
718 error_reason = "Stride value smaller than or equal zero"
719
720 if check:
721 input_dtype = kwargs["input_dtype"]
722 output_dtype = kwargs["output_dtype"]
723 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
724 stride = kwargs["stride"] # Work around wrong input/output type tests
725 elif output_dtype == DType.FLOAT:
726 stride = kwargs["stride_fp"]
727 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
728 stride = kwargs[
729 "stride_fp"
730 ] # Work around wrong input/output type tests
731 else:
732 stride = kwargs["stride"]
733
734 if min(stride) <= 0:
735 error_result = True
736
737 info_dict = {
738 "error_name": error_name,
739 "error_result": error_result,
740 "error_reason": error_reason,
741 "param_reqs": param_reqs,
742 }
743 return info_dict
744
745 @staticmethod
746 def evStrideLargerEqualMax(check=False, **kwargs):
747 error_name = ErrorIf.StrideLargerEqualMax
748 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
749 error_result = False
750 error_reason = "Stride value larger than or equal to maximum value"
751
752 if check:
753 shift = kwargs["shift"]
754 input_dtype = kwargs["input_dtype"]
755 stride = kwargs["stride"]
756 if input_dtype in [DType.INT8, DType.INT16]:
757 if shift >= 0 and (
758 stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
759 ):
760 error_result = True
761 elif shift < 0 and (
762 stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
763 ):
764 error_result = True
765
766 info_dict = {
767 "error_name": error_name,
768 "error_result": error_result,
769 "error_reason": error_reason,
770 "param_reqs": param_reqs,
771 }
772 return info_dict
773
774 @staticmethod
775 def evStrideLargerDimension(check=False, **kwargs):
776 error_name = ErrorIf.StrideLargerDimension
777 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
778 error_result = False
779 error_reason = "Stride value larger than or equal to H/W dimension"
780
781 if check:
782 shape = kwargs["input_shape"]
783 input_dtype = kwargs["input_dtype"]
784 stride = kwargs["stride_fp"]
785
786 if (
787 input_dtype == DType.FLOAT
788 and (stride[0] > shape[1])
789 or (stride[1] > shape[2])
790 ):
791 error_result = True
792
793 info_dict = {
794 "error_name": error_name,
795 "error_result": error_result,
796 "error_reason": error_reason,
797 "param_reqs": param_reqs,
798 }
799 return info_dict
800
801 @staticmethod
802 def evOffsetSmallerEqualMin(check=False, **kwargs):
803 error_name = ErrorIf.OffsetSmallerEqualMin
804 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
805 error_result = False
806 error_reason = "Offset value smaller than or equal to minimum value"
807
808 if check:
809 shift = kwargs["shift"]
810 output_dtype = kwargs["output_dtype"]
811 if output_dtype == DType.FLOAT:
812 offset = kwargs["offset_fp"]
813 else:
814 offset = kwargs["offset"]
815
816 if shift >= 0 and (
817 offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
818 ):
819 error_result = True
820 elif shift < 0 and (
821 offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
822 ):
823 error_result = True
824
825 info_dict = {
826 "error_name": error_name,
827 "error_result": error_result,
828 "error_reason": error_reason,
829 "param_reqs": param_reqs,
830 }
831 return info_dict
832
833 @staticmethod
834 def evOffsetLargerEqualMax(check=False, **kwargs):
835 error_name = ErrorIf.OffsetLargerEqualMax
836 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
837 error_result = False
838 error_reason = "Offset value larger than or equal to maximum value"
839
840 if check:
841 shift = kwargs["shift"]
842 output_dtype = kwargs["output_dtype"]
843 if output_dtype == DType.FLOAT:
844 offset = kwargs["offset_fp"]
845 else:
846 offset = kwargs["offset"]
847
848 if shift >= 0:
849 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
850 error_result = True
851
852 if shift >= 0 and (
853 offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
854 ):
855 error_result = True
856 elif shift < 0 and (
857 offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
858 ):
859 error_result = True
860
861 info_dict = {
862 "error_name": error_name,
863 "error_result": error_result,
864 "error_reason": error_reason,
865 "param_reqs": param_reqs,
866 }
867 return info_dict
868
869 @staticmethod
870 def evShiftNotZero(check=False, **kwargs):
871 error_name = ErrorIf.ShiftNotZero
872 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
873 error_result = False
874 error_reason = "Shift value must be zero for float input"
875
876 if check:
877 shift = kwargs["shift"]
878 input_dtype = kwargs["input_dtype"]
879 output_dtype = kwargs["output_dtype"]
880 if (
881 input_dtype == DType.FLOAT
882 and output_dtype == DType.FLOAT
883 and shift != 0
884 ):
885 error_result = True
886
887 info_dict = {
888 "error_name": error_name,
889 "error_result": error_result,
890 "error_reason": error_reason,
891 "param_reqs": param_reqs,
892 }
893 return info_dict
894
895 @staticmethod
896 def evShiftSmallerOne(check=False, **kwargs):
897 error_name = ErrorIf.ShiftSmallerOne
898 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
899 error_result = False
900 error_reason = "Shift value smaller than one"
901
902 if check:
903 shift = kwargs["shift"]
904 input_dtype = kwargs["input_dtype"]
905 output_dtype = kwargs["output_dtype"]
906 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
907 error_result = True
908
909 info_dict = {
910 "error_name": error_name,
911 "error_result": error_result,
912 "error_reason": error_reason,
913 "param_reqs": param_reqs,
914 }
915 return info_dict
916
917 @staticmethod
918 def evShiftLargerEleven(check=False, **kwargs):
919 error_name = ErrorIf.ShiftLargerEleven
920 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
921 error_result = False
922 error_reason = "Shift value larger than eleven"
923
924 if check:
925 shift = kwargs["shift"]
926 if shift > 11:
927 error_result = True
928
929 info_dict = {
930 "error_name": error_name,
931 "error_result": error_result,
932 "error_reason": error_reason,
933 "param_reqs": param_reqs,
934 }
935 return info_dict
936
937 @staticmethod
938 def evRankMismatch(check=False, **kwargs):
939 error_name = ErrorIf.RankMismatch
940 param_reqs = {"rank": None, "dtype": None, "shape": None}
941 error_result = False
942 error_reason = "Input Rank does not match output rank"
943
944 if check:
945 input1_shape = kwargs["input1"].shape
946 input2_shape = kwargs["input2"].shape
947 # In case of SELECT op
948 input3_shape = (
949 kwargs["input3"].shape if "input3" in kwargs else input2_shape
950 )
951 output_shape = kwargs["result_tensor"].shape
952 if (
953 (len(input1_shape) != len(output_shape))
954 or (len(input2_shape) != len(output_shape))
955 or (len(input3_shape) != len(output_shape))
956 ):
957 error_result = True
958
959 info_dict = {
960 "error_name": error_name,
961 "error_result": error_result,
962 "error_reason": error_reason,
963 "param_reqs": param_reqs,
964 }
965 return info_dict
966
@staticmethod
def evDimensionMismatch(check=False, **kwargs):
    """Report whether an input dimension is incompatible with the output.

    An input dimension is accepted when it is 1 or equal to the output
    dimension at the same index.  Only indices that exist in every input
    and in the output are compared, so tensors of differing rank do not
    raise IndexError here.

    With ``check`` true, ``kwargs`` must hold ``input1``, ``input2`` and
    ``result_tensor`` objects exposing ``.shape``; an optional ``input3``
    is compared as well.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.DimensionMismatch
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Input Dimensions do not match output"

    if check:
        input1_shape = kwargs["input1"].shape
        input2_shape = kwargs["input2"].shape
        # In case of SELECT op
        input3_shape = (
            kwargs["input3"].shape if "input3" in kwargs else input2_shape
        )
        output_shape = kwargs["result_tensor"].shape
        # Include the output rank in the bound so output_shape[i] is always valid
        common_rank = min(
            len(input1_shape),
            len(input2_shape),
            len(input3_shape),
            len(output_shape),
        )
        for i in range(common_rank):
            out_dim = output_shape[i]
            in_dims = (input1_shape[i], input2_shape[i], input3_shape[i])
            if any(dim != 1 and dim != out_dim for dim in in_dims):
                error_result = True
                break

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
999
1000 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001001 def _getZeroPoint(qinfo, index):
1002 """Return zero point value from quantization info.
1003
1004 Generally input_zp is index 0, output_zp is index 1
1005 """
1006 if isinstance(qinfo, tuple):
1007 zero_point = qinfo[index]
1008 else:
1009 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
1010 zero_point = qinfo.ints[index][1]
1011 return zero_point
1012
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
    """Report a non-zero input zero point on a non-quantizable input type.

    ``kwargs["op"]`` is always read: its ``types`` list is filtered to build
    the ``dtype`` entry of ``param_reqs``.  With ``check`` true, ``kwargs``
    must also hold ``input_dtype`` and ``qinfo``, plus ``input2_dtype`` when
    the op is MATMUL (where both inputs are examined).

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    op = kwargs["op"]

    # Quantizable types
    quantizable = (DType.INT8, DType.UINT8, DType.UINT16)

    # This does not apply to quantizable types; list entries are judged by
    # their first element
    inputDtypes = []
    for entry in op["types"]:
        leading = entry[0] if isinstance(entry, list) else entry
        if leading not in quantizable:
            inputDtypes.append(entry)

    error_result = False
    if check:
        input_dtype = kwargs["input_dtype"]
        first_zp = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
        if op["op"] != Op.MATMUL:
            error_result = input_dtype not in quantizable and first_zp != 0
        else:
            second_zp = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            operands = (
                (kwargs["input_dtype"], first_zp),
                (kwargs["input2_dtype"], second_zp),
            )
            error_result = any(
                dtype not in quantizable and zp != 0 for dtype, zp in operands
            )

    return {
        "error_name": ErrorIf.InputZeroPointNotZero,
        "error_result": error_result,
        "error_reason": "Input DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
    }
1051
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
    """Report a non-zero weight zero point when the weight type is not INT8.

    ``kwargs["op"]`` is always read to build the ``dtype`` entry of
    ``param_reqs``.  With ``check`` true, ``kwargs`` must also hold
    ``weight_dtype`` and ``qinfo``; the weight zero point is taken from
    index 1 of the quantization info.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    op = kwargs["op"]

    # exclude inputs with INT8 weights (second element of list entries)
    inputDtypes = []
    for entry in op["types"]:
        if not isinstance(entry, list) or entry[1] != DType.INT8:
            inputDtypes.append(entry)

    flagged = False
    if check:
        weight_dtype = kwargs["weight_dtype"]
        weight_zp = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
        flagged = bool(weight_dtype != DType.INT8 and weight_zp != 0)

    return {
        "error_name": ErrorIf.WeightZeroPointNotZero,
        "error_result": flagged,
        "error_reason": "Weight DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
    }
1079
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
    """Report a non-zero output zero point where one is not permitted.

    ``kwargs["op"]`` is always read to build the ``dtype`` entry of
    ``param_reqs``.  With ``check`` true, ``kwargs`` must also hold
    ``input_dtype``, ``output_dtype`` and ``qinfo``.  For AVG_POOL2D the
    decision is based on the input type being INT8; for every other op it
    is based on the output type being INT8, UINT8 or UINT16.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    op = kwargs["op"]
    zp_capable = [DType.INT8, DType.UINT8, DType.UINT16]
    inputDtypes = [entry for entry in op["types"] if entry not in zp_capable]

    flagged = False
    if check:
        input_dtype = kwargs["input_dtype"]
        output_dtype = kwargs["output_dtype"]
        output_zp = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
        if op["op"] == Op.AVG_POOL2D:
            zp_must_be_zero = input_dtype != DType.INT8
        else:
            zp_must_be_zero = output_dtype not in zp_capable
        flagged = bool(zp_must_be_zero and output_zp != 0)

    return {
        "error_name": ErrorIf.OutputZeroPointNotZero,
        "error_result": flagged,
        "error_reason": "Output DType not INT8 and zero point not 0",
        "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
    }
1112
@staticmethod
def evU16InputZeroPointNotValid(check=False, **kwargs):
    """Report a UINT16 input whose zero point is neither 0 nor 32768.

    With ``check`` true, ``kwargs`` must hold ``input_dtype`` and ``qinfo``;
    the input zero point is taken from index 0 of the quantization info.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.U16InputZeroPointNotValid
    param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
    error_result = False
    # The message previously said 32678, contradicting the value checked below
    error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

    if check:
        input_dtype = kwargs["input_dtype"]
        input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
        error_result = input_dtype == DType.UINT16 and input_zero_point not in [
            0,
            32768,
        ]

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1135
@staticmethod
def evU16OutputZeroPointNotValid(check=False, **kwargs):
    """Report a UINT16 output whose zero point is neither 0 nor 32768.

    With ``check`` true, ``kwargs`` must hold ``output_dtype`` and ``qinfo``;
    the output zero point is taken from index 1 of the quantization info.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.U16OutputZeroPointNotValid
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The message previously said 32678, contradicting the value checked below
    error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

    if check:
        output_dtype = kwargs["output_dtype"]
        output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

        error_result = output_dtype == DType.UINT16 and output_zero_point not in [
            0,
            32768,
        ]

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1159
@staticmethod
def evAxisSmallerZero(check=False, **kwargs):
    """Report a negative axis.

    With ``check`` true, ``kwargs`` must hold ``axis``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    negative_axis = bool(check and kwargs["axis"] < 0)

    return {
        "error_name": ErrorIf.AxisSmallerZero,
        "error_result": negative_axis,
        "error_reason": "Axis smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1179
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
    """Report an axis that lies beyond the last dimension of the input.

    Valid indices into ``input_shape`` are ``0 .. rank - 1``, so an axis
    equal to the rank is flagged as well as any larger value.

    With ``check`` true, ``kwargs`` must hold ``axis`` and ``input_shape``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.AxisLargerRank
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Axis larger than rank"

    if check:
        axis = kwargs["axis"]
        shape = kwargs["input_shape"]
        # axis == len(shape) is already out of range (was a strict ">")
        if axis >= len(shape):
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1200
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
    """Report an output dimension on ``axis`` that is not 1.

    With ``check`` true, ``kwargs`` must hold ``axis`` and ``output_shape``.
    An axis outside ``0 .. len(output_shape) - 1`` is never flagged here.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    flagged = False

    if check:
        axis = kwargs["axis"]
        out_dims = kwargs["output_shape"]
        # Only index the shape when the axis is a usable position
        if axis >= 0 and axis < len(out_dims):
            flagged = bool(out_dims[axis] != 1)

    return {
        "error_name": ErrorIf.ShapeOfAxisNotOne,
        "error_result": flagged,
        "error_reason": "shape[axis] is not equal to 1",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1221
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
    """Report any negative padding value.

    With ``check`` true, ``kwargs`` must hold ``op`` and ``pad``.  For the
    PAD op ``pad`` is treated as a sequence of sequences; for every other op
    it is treated as a flat sequence.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    has_negative = False

    if check:
        op = kwargs["op"]
        pad = kwargs["pad"]
        # Normalise to a list of groups so both layouts share one test
        groups = pad if op["op"] == Op.PAD else [pad]
        smallest = [min(group) for group in groups]
        has_negative = any(value < 0 for value in smallest)

    return {
        "error_name": ErrorIf.PadSmallerZero,
        "error_result": has_negative,
        "error_reason": "At least one pad is smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1247
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
    """Report a pad value that is not smaller than its kernel dimension.

    ``pad[0]`` and ``pad[1]`` are compared with ``kernel[0]``; ``pad[2]``
    and ``pad[3]`` are compared with ``kernel[1]``.  The comparison is only
    made when no pad is negative and no kernel dimension is below one;
    those cases have their own ERROR_IF names (PadSmallerZero,
    KernelSmallerOne).

    With ``check`` true, ``kwargs`` must hold ``pad`` and ``kernel``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.PadLargerEqualKernel
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "At least one pad is larger than kernel dimension"

    if check:
        pad = kwargs["pad"]
        kernel = kwargs["kernel"]
        # Zero pads and unit kernels are legal values, so they must not
        # disable the comparison (the guard used to be "> 0" / "> 1")
        if min(pad) >= 0 and min(kernel) >= 1:
            if (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            ):
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1274
1275 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001276 def checkPoolingParams(kernel, stride, pad):
1277 return (
1278 min(kernel) >= 1
1279 and min(stride) >= 1
1280 and min(pad) >= 0
1281 and not (
1282 pad[0] >= kernel[0]
1283 or pad[1] >= kernel[0]
1284 or pad[2] >= kernel[1]
1285 or pad[3] >= kernel[1]
1286 )
1287 )
1288
@staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
    """Report an output height/width that differs from the computed one.

    With ``check`` true, ``kwargs`` must hold ``pad``, ``kernel``,
    ``input_shape``, ``output_shape`` and ``stride``.  Height and width are
    read from positions 1 and 2 of the shapes.  Nothing is flagged unless
    ``checkPoolingParams`` accepts the kernel, stride and pad values.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        pad = kwargs["pad"]
        top, bottom, left, right = pad[0], pad[1], pad[2], pad[3]

        kernel = kwargs["kernel"]
        k_h, k_w = kernel[0], kernel[1]

        in_dims = kwargs["input_shape"]
        in_h, in_w = in_dims[1], in_dims[2]

        out_dims = kwargs["output_shape"]
        out_h, out_w = out_dims[1], out_dims[2]

        stride = kwargs["stride"]
        s_h, s_w = stride[0], stride[1]

        # ensure parameters are valid; this also guarantees non-zero strides
        if TosaErrorValidator.checkPoolingParams(kernel, stride, pad):
            # calculate correct height, width dimensions
            expected_h = ((in_h + top + bottom - k_h) // s_h) + 1
            expected_w = ((in_w + left + right - k_w) // s_w) + 1
            mismatch = bool(out_h != expected_h or out_w != expected_w)

    return {
        "error_name": ErrorIf.PoolingOutputShapeMismatch,
        "error_result": mismatch,
        "error_reason": (
            "Mismatch between output shape provided and expected output shape"
        ),
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1332
@staticmethod
def evPoolingOutputShapeNonInteger(check=False, **kwargs):
    """Report pooling parameters that do not divide evenly by the stride.

    With ``check`` true, ``kwargs`` must hold ``pad``, ``kernel``,
    ``input_shape`` and ``stride``.  Height and width are read from
    positions 1 and 2 of the input shape.  Nothing is flagged unless
    ``checkPoolingParams`` accepts the kernel, stride and pad values.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    uneven = False

    if check:
        pad = kwargs["pad"]
        top, bottom, left, right = pad[0], pad[1], pad[2], pad[3]

        kernel = kwargs["kernel"]
        k_h, k_w = kernel[0], kernel[1]

        in_dims = kwargs["input_shape"]
        in_h, in_w = in_dims[1], in_dims[2]

        stride = kwargs["stride"]
        s_h, s_w = stride[0], stride[1]

        # ensure parameters are valid; this also guarantees non-zero strides
        if TosaErrorValidator.checkPoolingParams(kernel, stride, pad):
            # calculate remainder of height, width dimensions
            rem_h = (in_h + top + bottom - k_h) % s_h
            rem_w = (in_w + left + right - k_w) % s_w
            uneven = bool(rem_h != 0 or rem_w != 0)

    return {
        "error_name": ErrorIf.PoolingOutputShapeNonInteger,
        "error_result": uneven,
        "error_reason": "Parameters do not yield exact integer output dimensions",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1370
1371 @staticmethod
1372 def checkConvParams(weight_shape, stride, pad, dilation):
1373 return (
1374 # Check kernel sizes
1375 min(weight_shape[1:-1]) >= 1
1376 and min(stride) >= 1
1377 and min(pad) >= 0
1378 and (dilation is None or min(dilation) >= 1)
1379 )
1380
@staticmethod
def evConvOutputShapeMismatch(check=False, **kwargs):
    """Report convolution output dimensions that differ from the computed ones.

    With ``check`` true, ``kwargs`` must hold ``op``, ``pad``,
    ``weight_shape``, ``input_shape``, ``output_shape`` and ``stride``, plus
    ``dilation`` unless the op is TRANSPOSE_CONV2D.  The computed dimensions
    are compared with ``output_shape[1:-1]``, and nothing is flagged unless
    ``checkConvParams`` accepts the parameters.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        op = kwargs["op"]
        pad = kwargs["pad"]
        weight_shape = kwargs["weight_shape"]
        input_shape = kwargs["input_shape"]
        output_shape = kwargs["output_shape"]
        is_transpose = op["op"] == Op.TRANSPOSE_CONV2D
        dilation = None if is_transpose else kwargs["dilation"]
        stride = kwargs["stride"]

        # Kernel sizes start at index 0 of the weights for depthwise, else 1
        kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

        # calculate correct dimensions
        expected_dims = []
        if min(stride) > 0:
            for axis, step in enumerate(stride):
                in_dim = input_shape[axis + 1]
                pad_before = pad[axis * 2]
                pad_after = pad[axis * 2 + 1]
                kernel_dim = weight_shape[axis + kernel_offset]
                if is_transpose:
                    dim = (in_dim - 1) * step - pad_before - pad_after + kernel_dim
                else:
                    span = (
                        in_dim
                        - 1
                        + pad_before
                        + pad_after
                        - (kernel_dim - 1) * dilation[axis]
                    )
                    dim = span // step + 1
                expected_dims.append(dim)

        # ensure parameters are valid
        params_valid = TosaErrorValidator.checkConvParams(
            weight_shape, stride, pad, dilation
        )

        if params_valid and output_shape[1:-1] != expected_dims:
            mismatch = True

    return {
        "error_name": ErrorIf.ConvOutputShapeMismatch,
        "error_result": mismatch,
        "error_reason": (
            "Mismatch between output shape provided and expected output shape"
        ),
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1442
@staticmethod
def evConvOutputShapeNonInteger(check=False, **kwargs):
    """Report convolution parameters that do not divide evenly by the stride.

    With ``check`` true, ``kwargs`` must hold ``op``, ``pad``,
    ``weight_shape``, ``input_shape``, ``dilation`` and ``stride``.  Nothing
    is flagged unless ``checkConvParams`` accepts the parameters.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    uneven = False

    if check:
        op = kwargs["op"]
        pad = kwargs["pad"]
        weight_shape = kwargs["weight_shape"]
        input_shape = kwargs["input_shape"]
        dilation = kwargs["dilation"]
        stride = kwargs["stride"]

        # Kernel sizes start at index 0 of the weights for depthwise, else 1
        kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

        # calculate the remainder left over on each spatial dimension
        remainders = []
        if min(stride) > 0:
            for axis, step in enumerate(stride):
                span = (
                    input_shape[axis + 1]
                    - 1
                    + pad[axis * 2]
                    + pad[axis * 2 + 1]
                    - (weight_shape[axis + kernel_offset] - 1) * dilation[axis]
                )
                remainders.append(span % step)

        # ensure parameters are valid
        params_valid = TosaErrorValidator.checkConvParams(
            weight_shape, stride, pad, dilation
        )
        if params_valid and max(remainders) > 0:
            uneven = True

    return {
        "error_name": ErrorIf.ConvOutputShapeNonInteger,
        "error_result": uneven,
        "error_reason": "Parameters do not yield exact integer output dimensions",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1491
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
    """Report an output shape that is not the input shape with ``axis`` removed.

    The dimensions are only compared when the output rank is exactly one
    less than the input rank and ``axis`` is a valid index into the input
    shape.  Without the axis guard no dimension would be skipped and the
    comparison would index past the end of the shorter output shape.

    With ``check`` true, ``kwargs`` must hold ``output_shape``,
    ``input_shape`` and ``axis``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.ArgmaxOutputShapeMismatch
    param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
    error_result = False
    error_reason = (
        "Mismatch between output shape provided and expected output shape"
    )

    if check:
        output_shape = kwargs["output_shape"]
        input_shape = kwargs["input_shape"]
        axis = kwargs["axis"]
        rank = len(input_shape)

        # Check that rank and axis are correct before trying to check dimensions
        if 0 <= axis < rank and (rank - 1) == len(output_shape):
            remaining = [dim for i, dim in enumerate(input_shape) if i != axis]
            error_result = any(
                expected != actual
                for expected, actual in zip(remaining, output_shape)
            )

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1528
@staticmethod
def evArgmaxOutputRankMismatch(check=False, **kwargs):
    """Report an output rank that is not exactly one less than the input rank.

    With ``check`` true, ``kwargs`` must hold ``output_shape``,
    ``input_shape`` and ``axis``.  Nothing is flagged when ``axis`` is
    outside ``0 .. len(input_shape) - 1``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        out_dims = kwargs["output_shape"]
        in_dims = kwargs["input_shape"]
        axis = kwargs["axis"]
        in_rank = len(in_dims)
        if 0 <= axis < in_rank:
            mismatch = len(out_dims) != in_rank - 1

    return {
        "error_name": ErrorIf.ArgmaxOutputRankMismatch,
        "error_result": mismatch,
        "error_reason": (
            "Mismatch between output shape provided and expected output shape"
        ),
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1554
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
    """Report a kernel dimension below one.

    With ``check`` true, ``kwargs`` must hold ``kernel``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.KernelSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The threshold tested below is one; the message used to say "zero"
    error_reason = "At least one kernel dimension is smaller than one"

    if check:
        kernel = kwargs["kernel"]
        if min(kernel) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1574
@staticmethod
def evStrideSmallerOne(check=False, **kwargs):
    """Report a stride dimension below one.

    With ``check`` true, ``kwargs`` must hold ``stride``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.StrideSmallerOne
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    # The threshold tested below is one; the message used to say "zero"
    error_reason = "At least one stride dimension is smaller than one"

    if check:
        stride = kwargs["stride"]
        if min(stride) < 1:
            error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1594
@staticmethod
def evDilationSmallerOne(check=False, **kwargs):
    """Report a dilation value below one.

    With a truthy ``check``, ``kwargs`` must hold ``dilation``.  When
    ``check`` is falsy it is passed through unchanged as ``error_result``.

    Returns a dict with the keys ``error_name``, ``error_reason``,
    ``param_reqs`` and ``error_result``.
    """
    outcome = check
    if check:
        outcome = min(kwargs["dilation"]) < 1

    info_dict = {
        "error_name": ErrorIf.DilationSmallerOne,
        "error_reason": "At least one dilation is smaller than one",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
        "error_result": outcome,
    }
    return info_dict
1604
@staticmethod
def evScaleTrue(check=False, **kwargs):
    """Report ``scale32`` being set together with an INT48 input.

    With ``check`` true, ``kwargs`` must hold ``input_dtype`` and
    ``scale32``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    flagged = False

    if check:
        input_dtype = kwargs["input_dtype"]
        scale32 = kwargs["scale32"]
        flagged = bool(scale32 and input_dtype == DType.INT48)

    return {
        "error_name": ErrorIf.ScaleTrue,
        "error_result": flagged,
        "error_reason": "Scale set to true but input type is INT48",
        "param_reqs": {"rank": None, "dtype": [DType.INT48], "shape": None},
    }
1625
@staticmethod
def evScaleNotTrue(check=False, **kwargs):
    """Report ``double_round`` being set while ``scale32`` is not.

    With ``check`` true, ``kwargs`` must hold ``scale32`` and
    ``double_round``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    flagged = False

    if check:
        scale32 = kwargs["scale32"]
        double_round = kwargs["double_round"]
        flagged = bool(not scale32 and double_round)

    return {
        "error_name": ErrorIf.ScaleNotTrue,
        "error_result": flagged,
        "error_reason": "Scale set to false but double round set to true",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1646
@staticmethod
def evTensorSizeInputOutputMismatch(check=False, **kwargs):
    """Report input and output shapes whose element counts differ.

    With ``check`` true, ``kwargs`` must hold ``input_shape`` and
    ``output_shape``; the element count is the product of each shape.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    sizes_differ = False

    if check:
        in_elems = np.prod(kwargs["input_shape"])
        out_elems = np.prod(kwargs["output_shape"])
        sizes_differ = bool(in_elems != out_elems)

    return {
        "error_name": ErrorIf.TensorSizeInputOutputMismatch,
        "error_result": sizes_differ,
        "error_reason": "Input tensor size does not match output tensor size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1669
@staticmethod
def evStartSmallerZero(check=False, **kwargs):
    """Report a negative entry in ``start``.

    With ``check`` true, ``kwargs`` must hold ``input_shape`` and ``start``.
    The entries are only examined when ``start`` has as many entries as the
    input shape.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    has_negative = False

    if check:
        in_dims = kwargs["input_shape"]
        start = kwargs["start"]
        if len(start) == len(in_dims):
            has_negative = any(begin < 0 for begin in start)

    return {
        "error_name": ErrorIf.StartSmallerZero,
        "error_result": has_negative,
        "error_reason": "Starting point smaller than zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1693
@staticmethod
def evSizeSmallerEqualZero(check=False, **kwargs):
    """Report an entry in ``size`` that is zero or negative.

    With ``check`` true, ``kwargs`` must hold ``input_shape`` and ``size``.
    The entries are only examined when ``size`` has as many entries as the
    input shape.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    has_non_positive = False

    if check:
        in_dims = kwargs["input_shape"]
        size = kwargs["size"]
        if len(size) == len(in_dims):
            has_non_positive = any(extent <= 0 for extent in size)

    return {
        "error_name": ErrorIf.SizeSmallerEqualZero,
        "error_result": has_non_positive,
        "error_reason": "Size smaller than or equal to zero",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1717
@staticmethod
def evStartSizeOutsideBounds(check=False, **kwargs):
    """Report a dimension where ``start + size`` exceeds the input dimension.

    With ``check`` true, ``kwargs`` must hold ``input_shape``, ``start`` and
    ``size``.  The entries are only examined when both ``start`` and
    ``size`` have as many entries as the input shape.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    out_of_bounds = False

    if check:
        in_dims = kwargs["input_shape"]
        start = kwargs["start"]
        size = kwargs["size"]
        rank = len(in_dims)
        if len(start) == rank and len(size) == rank:
            out_of_bounds = any(
                begin + extent > limit
                for begin, extent, limit in zip(start, size, in_dims)
            )

    return {
        "error_name": ErrorIf.StartSizeOutsideBounds,
        "error_result": out_of_bounds,
        "error_reason": "starting point plus size larger than input dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1742
@staticmethod
def evSizeOutputShapeMismatch(check=False, **kwargs):
    """Report an entry in ``size`` that differs from the output dimension.

    With ``check`` true, ``kwargs`` must hold ``input_shape``,
    ``output_shape`` and ``size``.  The entries are only examined when
    ``size`` has as many entries as the input shape.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        in_dims = kwargs["input_shape"]
        out_dims = kwargs["output_shape"]
        size = kwargs["size"]
        rank = len(in_dims)
        if len(size) == rank:
            # Evaluate every position before reducing, as indexing
            # out_dims is done for each of the rank positions
            differs = [size[pos] != out_dims[pos] for pos in range(rank)]
            mismatch = any(differs)

    return {
        "error_name": ErrorIf.SizeOutputShapeMismatch,
        "error_result": mismatch,
        "error_reason": "Size does not match output dimension",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1767
@staticmethod
def evInputSizeStartLengthMismatch(check=False, **kwargs):
    """Report ``start`` or ``size`` having a length other than the input rank.

    With ``check`` true, ``kwargs`` must hold ``input_shape``, ``start`` and
    ``size``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    lengths_differ = False

    if check:
        in_dims = kwargs["input_shape"]
        start = kwargs["start"]
        size = kwargs["size"]
        lengths_differ = not (len(in_dims) == len(start) == len(size))

    return {
        "error_name": ErrorIf.InputSizeStartLengthMismatch,
        "error_result": lengths_differ,
        "error_reason": "rank of input not equal to length of start or size",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1790
@staticmethod
def evIndexOutsideBounds(check=False, **kwargs):
    """Report an entry of ``perms`` that is not a valid input dimension index.

    Valid indices are ``0 .. rank - 1`` where ``rank`` is the length of the
    input shape, so an entry equal to the rank is flagged as well as any
    larger or negative value.

    With ``check`` true, ``kwargs`` must hold ``input_shape`` and ``perms``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.IndexOutsideBounds
    param_reqs = {"rank": None, "dtype": None, "shape": None}
    error_result = False
    error_reason = "Index outside of allowed bounds"

    if check:
        input_shape = kwargs["input_shape"]
        perms = kwargs["perms"]
        rank = len(input_shape)

        # index == rank is already out of range (was a strict ">")
        error_result = any(index < 0 or index >= rank for index in perms)

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1814
@staticmethod
def evIndexUsedTwice(check=False, **kwargs):
    """Report a value that occurs more than once in ``perms``.

    With ``check`` true, ``kwargs`` must hold ``perms``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    has_repeat = False

    if check:
        ordered = list(kwargs["perms"])
        # A value repeats when it already appears earlier in the sequence
        has_repeat = any(
            value in ordered[:position] for position, value in enumerate(ordered)
        )

    return {
        "error_name": ErrorIf.IndexUsedTwice,
        "error_result": has_repeat,
        "error_reason": "Index used multiple times",
        "param_reqs": {"rank": [2, 4], "dtype": None, "shape": None},
    }
1839
@staticmethod
def evMaxSmallerMin(check=False, **kwargs):
    """Report ``max_val`` being smaller than ``min_val``.

    With ``check`` true, ``kwargs`` must hold ``max_val`` and ``min_val``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    inverted = False

    if check:
        upper = kwargs["max_val"]
        lower = kwargs["min_val"]
        inverted = bool(upper < lower)

    return {
        "error_name": ErrorIf.MaxSmallerMin,
        "error_result": inverted,
        "error_reason": "Max value smaller than min value",
        "param_reqs": {"rank": [2, 4], "dtype": None, "shape": None},
    }
1860
@staticmethod
def evConcatInputRankMismatch(check=False, **kwargs):
    """Report an input whose rank differs from the rank of ``input_shape``.

    With ``check`` true, ``kwargs`` must hold ``inputs`` (objects exposing
    ``.shape``) and ``input_shape``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    ranks_differ = False

    if check:
        tensors = kwargs["inputs"]
        expected_rank = len(kwargs["input_shape"])
        ranks_differ = any(
            len(tensor.shape) != expected_rank for tensor in tensors
        )

    return {
        "error_name": ErrorIf.ConcatInputRankMismatch,
        "error_result": ranks_differ,
        "error_reason": "Input ranks are not identical",
        "param_reqs": {"rank": [2, 4], "dtype": None, "shape": None},
    }
1882
@staticmethod
def evConcatInputDimMismatch(check=False, **kwargs):
    """Report an input that differs from ``input_shape`` away from ``axis``.

    With ``check`` true, ``kwargs`` must hold ``inputs`` (objects exposing
    ``.shape``), ``input_shape`` and ``axis``.  Dimensions are only compared
    when every input has the same rank as ``input_shape``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    dims_differ = False

    if check:
        tensors = kwargs["inputs"]
        reference = kwargs["input_shape"]
        axis = kwargs["axis"]

        # Ensure rank is valid before checking dims.
        if all(len(tensor.shape) == len(reference) for tensor in tensors):
            dims_differ = any(
                dim != reference[pos] and axis != pos
                for tensor in tensors
                for pos, dim in enumerate(tensor.shape)
            )

    return {
        "error_name": ErrorIf.ConcatInputDimMismatch,
        "error_result": dims_differ,
        "error_reason": "Input dimensions differ on too many axes",
        "param_reqs": {"rank": [2, 4], "dtype": None, "shape": None},
    }
1914
@staticmethod
def evConcatShapeSumMismatch(check=False, **kwargs):
    """Report an output ``axis`` dimension that is not the sum of the inputs'.

    The sum is only computed when every input has the same rank as
    ``input_shape`` and ``axis`` is a valid index into that shape
    (``0 <= axis < rank``).  An axis equal to the rank is therefore treated
    as invalid rather than being used to index the input shapes.

    With ``check`` true, ``kwargs`` must hold ``inputs`` (objects exposing
    ``.shape``), ``input_shape``, ``output_shape`` and ``axis``.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    error_name = ErrorIf.ConcatShapeSumMismatch
    param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
    error_result = False
    error_reason = "Sum of dimensions on axis not equal to output dimension"

    if check:
        inputs = kwargs["inputs"]
        input_shape = kwargs["input_shape"]
        output_shape = kwargs["output_shape"]
        axis = kwargs["axis"]
        rank = len(input_shape)

        # Ensure rank and axis are valid before checking dims.
        ranks_match = all(len(tensor.shape) == rank for tensor in inputs)
        if ranks_match and 0 <= axis < rank:
            axis_dim_sum = sum(tensor.shape[axis] for tensor in inputs)
            if axis_dim_sum != output_shape[axis]:
                error_result = True

    info_dict = {
        "error_name": error_name,
        "error_result": error_result,
        "error_reason": error_reason,
        "param_reqs": param_reqs,
    }
    return info_dict
1951
@staticmethod
def evInputListThenGraphMismatch(check=False, **kwargs):
    """Report input shapes that differ from the then-graph's input shapes.

    With ``check`` true, ``kwargs`` must hold ``a``, ``b`` (objects exposing
    ``.shape``) and ``basicBlocks``.  The block at index 1 is used as the
    then-graph; its first two inputs are looked up in its ``tensors``
    mapping and compared with ``a`` and ``b`` respectively.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        a = kwargs["a"]
        b = kwargs["b"]
        then_block = kwargs["basicBlocks"][1]
        names = then_block.inputs
        tensors = then_block.tensors
        first_differs = a.shape != tensors[names[0]].shape
        mismatch = bool(first_differs or b.shape != tensors[names[1]].shape)

    return {
        "error_name": ErrorIf.CondIfInputListThenGraphMismatch,
        "error_result": mismatch,
        "error_reason": "Input list shape does not match then-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
1978
@staticmethod
def evInputListElseGraphMismatch(check=False, **kwargs):
    """Report input shapes that differ from the else-graph's input shapes.

    With ``check`` true, ``kwargs`` must hold ``a``, ``b`` (objects exposing
    ``.shape``) and ``basicBlocks``.  The block at index 2 is used as the
    else-graph; its first two inputs are looked up in its ``tensors``
    mapping and compared with ``a`` and ``b`` respectively.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        a = kwargs["a"]
        b = kwargs["b"]
        else_block = kwargs["basicBlocks"][2]
        names = else_block.inputs
        tensors = else_block.tensors
        first_differs = a.shape != tensors[names[0]].shape
        mismatch = bool(first_differs or b.shape != tensors[names[1]].shape)

    return {
        "error_name": ErrorIf.CondIfInputListElseGraphMismatch,
        "error_result": mismatch,
        "error_reason": "Input list shape does not match else-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2005
@staticmethod
def evOutputListThenGraphMismatch(check=False, **kwargs):
    """Report a then-graph output shape that differs from the cond output.

    With ``check`` true, ``kwargs`` must hold ``basicBlocks``.  The first
    output of the block at index 1 (then-graph) is compared with the first
    output of the block at index 0, each looked up in its own block's
    ``tensors`` mapping.

    Returns a dict with the keys ``error_name``, ``error_result``,
    ``error_reason`` and ``param_reqs``.
    """
    mismatch = False

    if check:
        blocks = kwargs["basicBlocks"]
        cond_block = blocks[0]
        cond_out = cond_block.tensors[cond_block.outputs[0]]
        then_block = blocks[1]
        then_out = then_block.tensors[then_block.outputs[0]]
        mismatch = bool(then_out.shape != cond_out.shape)

    return {
        "error_name": ErrorIf.CondIfOutputListThenGraphMismatch,
        "error_result": mismatch,
        "error_reason": "Output list shape does not match then-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
2031
@staticmethod
def evOutputListElseGraphMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the else-graph output shape mismatch.

    Args:
        check: when False only the static description is returned; when
            True the condition is evaluated from kwargs["basicBlocks"].
        **kwargs: with check=True must supply "basicBlocks" (indexable);
            entry 0 provides the reference output and entry 2 is treated
            as the else block. Both must expose ``outputs`` and ``tensors``.

    Returns:
        dict with the keys "error_name", "error_result", "error_reason"
        and "param_reqs".
    """
    info = {
        "error_name": ErrorIf.CondIfOutputListElseGraphMismatch,
        "error_result": False,
        "error_reason": "Output list shape does not match else-graph shape",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if check:
        blocks = kwargs["basicBlocks"]
        reference_block = blocks[0]
        reference_out_keys = reference_block.outputs
        reference_tensors = reference_block.tensors
        else_graph = blocks[2]
        else_out_keys = else_graph.outputs
        else_tensors = else_graph.tensors
        # Only the first output of each block is compared.
        else_shape = else_tensors[else_out_keys[0]].shape
        reference_shape = reference_tensors[reference_out_keys[0]].shape
        if else_shape != reference_shape:
            info["error_result"] = True
    return info
2057
@staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the input/output list shape mismatch.

    Args:
        check: when False only the static description is returned; when
            True the condition is evaluated from kwargs["basicBlocks"].
        **kwargs: with check=True must supply "basicBlocks" (indexable);
            entry 0 is treated as the while block and must expose
            ``inputs``, ``outputs`` and ``tensors``.

    Returns:
        dict with the keys "error_name", "error_result", "error_reason"
        and "param_reqs".
    """
    info = {
        "error_name": ErrorIf.InputListOutputListMismatch,
        "error_result": False,
        "error_reason": "Input list does not match output list",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info

    loop_block = kwargs["basicBlocks"][0]
    in_keys = loop_block.inputs
    out_keys = loop_block.outputs
    tensors = loop_block.tensors
    # The block's input at index 1 is compared with its output at index 0.
    in_shape = tensors[in_keys[1]].shape
    out_shape = tensors[out_keys[0]].shape
    if in_shape != out_shape:
        info["error_result"] = True
    return info
2081
@staticmethod
def evInputListCondGraphMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the cond-graph input shape mismatch.

    Args:
        check: when False only the static description is returned; when
            True the condition is evaluated from kwargs["basicBlocks"].
        **kwargs: with check=True must supply "basicBlocks" (indexable);
            entry 0 is treated as the while block and entry 1 as the cond
            block. Both must expose ``inputs`` and ``tensors``.

    Returns:
        dict with the keys "error_name", "error_result", "error_reason"
        and "param_reqs".
    """
    info = {
        "error_name": ErrorIf.InputListCondGraphMismatch,
        "error_result": False,
        "error_reason": "Input list does not match cond graph",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info

    blocks = kwargs["basicBlocks"]
    loop_block = blocks[0]
    loop_in_keys = loop_block.inputs
    loop_tensors = loop_block.tensors
    cond_graph = blocks[1]
    cond_in_keys = cond_graph.inputs
    cond_tensors = cond_graph.tensors
    # Pairing used here: while input 0 <-> cond input 0,
    # while input 1 <-> cond input 2.
    if (
        loop_tensors[loop_in_keys[0]].shape != cond_tensors[cond_in_keys[0]].shape
        or loop_tensors[loop_in_keys[1]].shape
        != cond_tensors[cond_in_keys[2]].shape
    ):
        info["error_result"] = True
    return info
2109
@staticmethod
def evInputListBodyGraphInputMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the body-graph input shape mismatch.

    Args:
        check: when False only the static description is returned; when
            True the condition is evaluated from kwargs["basicBlocks"].
        **kwargs: with check=True must supply "basicBlocks" (indexable);
            entry 0 is treated as the while block and entry 2 as the body
            block. Both must expose ``inputs`` and ``tensors``.

    Returns:
        dict with the keys "error_name", "error_result", "error_reason"
        and "param_reqs".
    """
    info = {
        "error_name": ErrorIf.InputListBodyGraphInputMismatch,
        "error_result": False,
        "error_reason": "Input list does not match body graph input",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info

    blocks = kwargs["basicBlocks"]
    loop_block = blocks[0]
    loop_in_keys = loop_block.inputs
    loop_tensors = loop_block.tensors
    body_graph = blocks[2]
    body_in_keys = body_graph.inputs
    body_tensors = body_graph.tensors
    # Pairing used here: while input 0 <-> body input 0,
    # while input 1 <-> body input 2.
    if (
        loop_tensors[loop_in_keys[0]].shape != body_tensors[body_in_keys[0]].shape
        or loop_tensors[loop_in_keys[1]].shape
        != body_tensors[body_in_keys[2]].shape
    ):
        info["error_result"] = True
    return info
2139
@staticmethod
def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
    """Describe, and optionally evaluate, the body-graph output shape mismatch.

    Args:
        check: when False only the static description is returned; when
            True the condition is evaluated from kwargs["basicBlocks"].
        **kwargs: with check=True must supply "basicBlocks" (indexable);
            entry 0 is treated as the while block (``inputs``/``tensors``
            are read) and entry 2 as the body block (``outputs``/``tensors``
            are read).

    Returns:
        dict with the keys "error_name", "error_result", "error_reason"
        and "param_reqs".
    """
    info = {
        "error_name": ErrorIf.InputListBodyGraphOutputMismatch,
        "error_result": False,
        "error_reason": "Input list does not match body graph output",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info

    blocks = kwargs["basicBlocks"]
    loop_block = blocks[0]
    loop_in_keys = loop_block.inputs
    loop_tensors = loop_block.tensors
    body_graph = blocks[2]
    body_out_keys = body_graph.outputs
    body_tensors = body_graph.tensors
    # Pairing used here: while input 0 <-> body output 0,
    # while input 1 <-> body output 2.
    if (
        loop_tensors[loop_in_keys[0]].shape != body_tensors[body_out_keys[0]].shape
        or loop_tensors[loop_in_keys[1]].shape
        != body_tensors[body_out_keys[2]].shape
    ):
        info["error_result"] = True
    return info
2168
@staticmethod
def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
    """Describe, and optionally evaluate, a non-boolean cond-graph output.

    Args:
        check: when False only the static description is returned; when
            True the condition is evaluated from kwargs["basicBlocks"].
        **kwargs: with check=True must supply "basicBlocks" (indexable);
            entry 1 is treated as the cond block and must expose
            ``outputs`` and ``tensors``.

    Returns:
        dict with the keys "error_name", "error_result", "error_reason"
        and "param_reqs".
    """
    info = {
        "error_name": ErrorIf.CondGraphOutputNotMatchingBool,
        "error_result": False,
        "error_reason": "Cond graph output is not a match list of booleans",
        "param_reqs": {"rank": None, "dtype": None, "shape": None},
    }
    if not check:
        return info

    cond_graph = kwargs["basicBlocks"][1]
    out_keys = cond_graph.outputs
    tensors = cond_graph.tensors
    # Only the first output of the cond block is inspected.
    first_output = tensors[out_keys[0]]
    if first_output.dtype != DType.BOOL:
        info["error_result"] = True
    return info
2191
2192
2193class TosaInvalidValidator:
@staticmethod
def ivWrongDataTypeOrModeResize(**kwargs):
    """Return True when a resize test uses an invalid mode or dtype pairing.

    Args:
        **kwargs: must contain "input_dtype" and "args", where args[0] is
            the resize mode and args[8] is the output data type.

    Returns:
        True if the combination is invalid, False if it is valid.
    """
    input_dtype = kwargs["input_dtype"]
    args = kwargs["args"]
    mode = args[0]
    output_dtype = args[8]

    if mode == ResizeMode.BILINEAR:
        # Valid only for these exact (input, output) pairings.  The former
        # expression OR-ed the negated pairings together; since at most one
        # pairing can hold at a time it was always True.
        valid_pairs = (
            (DType.INT8, DType.INT32),
            (DType.INT16, DType.INT48),
            (DType.FLOAT, DType.FLOAT),
        )
        return not any(
            input_dtype == in_t and output_dtype == out_t
            for in_t, out_t in valid_pairs
        )
    elif mode == ResizeMode.NEAREST:
        # Output must equal input, and the input must be one of the input
        # types accepted above (INT16, not INT32, is the 16-bit entry).
        return (input_dtype != output_dtype) or (
            input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
        )
    else:
        # Invalid resize mode
        return True
2217
@staticmethod
def ivBadStride(**kwargs):
    """Return True when the stride relevant to the input dtype is <= 0.

    Args:
        **kwargs: must contain "input_dtype" and "args"; args[1] holds the
            two integer stride values and args[4] the two floating-point
            stride values.

    Returns:
        True if either selected stride component is negative or zero,
        otherwise False.
    """
    input_dtype = kwargs["input_dtype"]
    args = kwargs["args"]
    # Both stride pairs are read up front, whichever one ends up being used.
    int_strides = (args[1][0], args[1][1])
    fp_strides = (args[4][0], args[4][1])

    # FLOAT inputs are judged on the floating-point strides, everything
    # else on the integer strides.
    selected = fp_strides if input_dtype == DType.FLOAT else int_strides
    return any(component <= 0 for component in selected)
2236
@staticmethod
def ivHeightWidthInvalid(**kwargs):
    """Return True when a pooling/convolution test has an invalid output size.

    Args:
        **kwargs: must contain
            "opName": operator name, used to select the size formula;
            "shapeList": input shapes, entry 0 being the input tensor
                shape and, for convolutions, entry 1 the filter shape;
            "args": operator arguments, entry 0 being the strides and
                entry 1 the padding; entry 2 is the kernel shape for
                pooling or the dilations for convolutions, and entry 3 is
                the output shape for transpose_conv2d.

    Returns:
        True if the output size is invalid for the operator, else False.

    Raises:
        AssertionError: if opName matches none of the handled operators.
    """
    opName = kwargs["opName"]

    inputShapes = kwargs["shapeList"]
    input_shape = inputShapes[0]

    args = kwargs["args"]
    strides = args[0]
    padding = args[1]

    if opName.endswith("pool2d"):
        # avg_pool2d, max_pool2d
        kernel_shape = args[2]
        h = (
            input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
        ) // strides[0]
        w = (
            input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
        ) // strides[1]
        # return True if any dimension is < 1
        return h < 1 or w < 1

    if opName.startswith("transpose_conv2d"):
        # transpose_conv2d
        dilations = args[2]
        output_shape = args[3]
        filter_shape = inputShapes[1]
        kernel_shape = filter_shape[1:-1]

        def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
            """Calculate the transpose_conv2d output size for a dimension.

            Based on the keras function deconv_output_length, in
            https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py

            Args:
                in_size: the input size - int
                stride: the stride - int
                kernel_size: the kernel size - int
                dilation: the kernel dilation - int
                out_pad: the output padding - int
                in_pad: the input padding - int

            Returns:
                the output size
            """
            dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
            return (
                (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
            )

        # The requested output shape is valid if it matches the size
        # produced by at least one of the candidate input paddings.
        for pad_h, pad_w in (
            (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
            (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
            (0, 0),  # VALID padding
        ):
            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                dilations[0],
                padding[0],
                pad_h,
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                dilations[1],
                padding[1],
                pad_w,
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False

        # output shape does not match the expected shape for any padding option
        return True

    if "conv2d" in opName or "conv3d" in opName:
        # conv2d, conv3d, depthwise_conv2d
        dilations = args[2]
        filter_shape = inputShapes[1]
        kernel_shape = (
            filter_shape[0:2]
            if opName.startswith("depthwise_conv2d")
            else filter_shape[1:-1]
        )

        for i, kernel in enumerate(kernel_shape):
            dim = (
                input_shape[i + 1]
                - kernel
                - (kernel - 1) * (dilations[i] - 1)
                + padding[i * 2 + 0]
                + padding[i * 2 + 1]
            ) // strides[i] + 1
            # return True if any dimension is < 1
            if dim < 1:
                return True
        return False

    # Raised explicitly rather than via `assert False`, which is removed
    # under `python -O` and would let this method silently return None.
    raise AssertionError(f"Unrecognized Op: {opName}")
2340
@staticmethod
def ivNonPositiveOutputShape(**kwargs):
    """Return True when the requested output shape has a non-positive extent.

    Args:
        **kwargs: must contain "args", whose entry 3 is the output shape.

    Returns:
        True if the value at index 1 or index 2 of the output shape is
        zero or negative, otherwise False.
    """
    requested_shape = kwargs["args"][3]
    # Lazy evaluation: index 2 is only read when index 1 is positive.
    return any(requested_shape[axis] <= 0 for axis in (1, 2))