blob: b331a42bb337f77dbdee8297a298a1f89d0f27d7 [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
4from generator.tosa_utils import product
5from generator.tosa_utils import usableDTypes
6from generator.tosa_utils import valueToName
7from tosa.DType import DType
8from tosa.Op import Op
9from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010
Matthew Haddone86fd342021-09-07 16:12:21 +010011
class ErrorIf(object):
    """Names of the TOSA ERROR_IF conditions that the test generator can
    deliberately trigger.

    Each constant's value is its own name; the strings are matched against
    the ``error_name`` returned by the ``TosaErrorValidator.ev*`` checks.
    """

    # RESIZE parameter errors
    MaxDimExceeded = "MaxDimExceeded"
    StrideSmallerEqualZero = "StrideSmallerEqualZero"
    StrideLargerEqualMax = "StrideLargerEqualMax"
    StrideLargerDimension = "StrideLargerDimension"
    OffsetSmallerEqualMin = "OffsetSmallerEqualMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    ShiftNotZero = "ShiftNotZero"
    ShiftSmallerOne = "ShiftSmallerOne"
    ShiftLargerEleven = "ShiftLargerEleven"
    # Generic interface errors
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    # Quantization zero-point errors
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    # Axis / reduction errors
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    # Pooling / convolution parameter errors
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    # RESCALE errors
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    # Reshape / slice / gather errors
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    # CLAMP errors
    MaxSmallerMin = "MaxSmallerMin"
    # CONCAT errors
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    # Control-flow (COND_IF / WHILE_LOOP) errors
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    # UINT16 RESCALE zero-point errors
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010073
74
class TosaErrorIfArgGen:
    """Helpers that mutate operator arguments so a specific ERROR_IF
    condition is triggered when the test is run.

    Fixes applied in review:
    - ``eiSliceErrorIf`` gains the ``@staticmethod`` decorator for
      consistency with every sibling method (behavior unchanged in Py3).
    - ``eiCastErrorIf`` used ``assert True`` in its unsupported-dtype
      branch, which can never fire and would fall through to return an
      unbound name; corrected to ``assert False``.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        shift,
        stride,
        stride_fp,
        offset,
        offset_fp,
    ):
        """Adjust RESIZE arguments to trigger ``error_name``.

        Returns the (possibly modified) tuple
        ``(shift, stride, stride_fp, offset, offset_fp, outputDType)``.
        """
        if outputDType == DType.FLOAT:
            # Floating-point resize uses the *_fp parameters.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride_fp = testGen.rng.random(size=[2]) - 2
            elif error_name == ErrorIf.ShiftNotZero:
                shift = testGen.rng.integers(1, 5)
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
        else:
            # Integer resize uses the fixed-point stride/offset/shift.
            if error_name == ErrorIf.StrideSmallerEqualZero:
                stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
            elif error_name == ErrorIf.ShiftSmallerOne:
                shift = testGen.rng.integers(-3, 1)
                if shift <= 0:
                    stride = [
                        (16 >> -shift) - 1,
                        (16 >> -shift) - 1,
                    ]  # avoids other ERROR_IF checks
                    offset = [
                        (16 >> -shift) - 1,
                        (16 >> -shift) - 1,
                    ]  # avoids other ERROR_IF checks
                else:
                    stride = [
                        (16 << shift) - 1,
                        (16 << shift) - 1,
                    ]  # avoids other ERROR_IF checks
                    offset = [
                        (16 << shift) - 1,
                        (16 << shift) - 1,
                    ]  # avoids other ERROR_IF checks
            elif error_name == ErrorIf.ShiftLargerEleven:
                shift = np.int16(testGen.rng.integers(12, 15))
            elif error_name == ErrorIf.StrideLargerDimension:
                shape = shapeList[0]
                transform_height = testGen.rng.choice([False, True])
                if transform_height:
                    stride[0] = shape[1] + testGen.rng.integers(1, 10)
                else:
                    stride[1] = shape[2] + testGen.rng.integers(1, 10)
            elif error_name == ErrorIf.StrideLargerEqualMax:
                stride = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetLargerEqualMax:
                offset = [(16 << shift) + 1, (16 << shift) + 1]
            elif error_name == ErrorIf.OffsetSmallerEqualMin:
                offset = [(-16 << shift) - 1, (-16 << shift) - 1]

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output dtype that is invalid for this mode/input dtype.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return shift, stride, stride_fp, offset, offset_fp, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) mutated to trigger ``error_name``,
        or (None, None, None) when the error does not apply."""
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when ``output_dtype`` is invalid for a RESCALE with
        ``input_dtype`` (per the allowed conversion table)."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for ERROR_IF checks.

        Randomly adds a dummy entry or drops entries so the list length no
        longer matches what the operator expects.
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        # Shrink every dimension together until the element count fits.
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # Review fix: decorator added for consistency with sibling methods.
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) for SLICE mutated to trigger ``error_name``."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output dtypes for ``input_dtype``."""
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # Review fix: was `assert True`, which never fires and would
            # then return an unbound name.
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
322
323
324class TosaErrorValidator:
325 @staticmethod
326 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
327 """Check ERROR_IF statements are caught and set the expected result.
328
329 Args:
330 serializer: the serializer to set the expected result in
331 validator_fcns: a sequence of validator functions to verify the result
332 error_name: the name of the ERROR_IF condition to check for
333 kwargs: keyword arguments for the validator functions
334 Returns:
335 True if the result matches the expected result; otherwise False
336 """
337 overall_result = True
338 for val_fcn in validator_fcns:
339 val_result = val_fcn(True, **kwargs)
340 validator_name = val_result["error_name"]
341 error_result = val_result["error_result"]
342 error_reason = val_result["error_reason"]
343
344 # expect an error IFF the error_name and validator_name match
345 expected_result = error_result == (error_name == validator_name)
346 overall_result &= expected_result
347
348 if expected_result and error_result:
349 serializer.setExpectedReturnCode(2, True, desc=error_reason)
350 elif error_result: # and not expected_result
351 print(
352 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
353 f" Expected: {error_name}, Got: {validator_name}"
354 )
355 elif not expected_result: # and not error_result
356 print(
357 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
358 f" Expected: {error_name}"
359 )
360
361 if not expected_result:
362 for k, v in sorted(kwargs.items()):
363 if k != "op":
364 if k.endswith("dtype"):
365 v = valueToName(DType, v)
366 print(f" {k} = {v}")
367
368 return overall_result
369
370 @staticmethod
371 def evWrongInputType(check=False, **kwargs):
372 error_result = False
373
374 # Find the unsupported input data types
375 op = kwargs["op"]
376 input_dtypes = op["types"]
377 allowed_input_dtypes = {
378 t[0] if isinstance(t, list) else t for t in input_dtypes
379 }
380 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
381
382 if op["op"] == Op.CLAMP:
383 wrong_input_dtypes.remove(DType.INT48)
384
385 if check:
386 input_dtype = kwargs["input_dtype"]
387 if input_dtype not in allowed_input_dtypes:
388 error_result = True
389
390 info_dict = {
391 "error_name": ErrorIf.WrongInputType,
392 "error_result": error_result,
393 "error_reason": "Input data type not supported for this operator",
394 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
395 }
396 return info_dict
397
398 @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Build the info dict for the WrongOutputType ERROR_IF condition.

        Flags an error when the output dtype is not the one mandated for the
        operator / input dtype (and, for RESIZE, mode) combination.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # RESIZE output dtype depends on both mode and input dtype.
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # Delegate to the shared RESCALE conversion table.
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator widths: INT8->INT32, INT16->INT48, FLOAT->FLOAT.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always produces INT32 indices.
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL widens to INT32; float stays float.
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op["op"] == Op.TABLE:
                # TABLE: INT8 lookup -> INT8, INT16 lookup -> INT32.
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison ops always produce BOOL.
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input dtype has a fixed set of legal cast targets.
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                # Convolution accumulators: INT8->INT32, INT16->INT48,
                # FLOAT->FLOAT (note: `and` binds tighter than `or` here).
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default: output dtype must match the input dtype.
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
528
529 @staticmethod
530 def evWrongRank(check=False, **kwargs):
531 all_ranks = (1, 2, 3, 4, 5)
532
533 # Make a list of incorrect ranks
534 assert "op" in kwargs
535 op = kwargs["op"]
536 rmin, rmax = op["rank"]
537 rank_range = range(rmin, rmax + 1)
538 incorrect_ranks = list(set(all_ranks) - set(rank_range))
539 # Remove small incorrect ranks to avoid index errors
540 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
541 # Set minimum incorrect rank to 3 to avoid index error
542 if op["op"] in [Op.RESIZE]:
543 incorrect_ranks = [3, 5]
544 elif op["op"] in [Op.TRANSPOSE]:
545 incorrect_ranks = [7, 8]
546 elif op["op"] in [Op.CONV3D]:
547 incorrect_ranks = [6, 7]
548
549 error_name = ErrorIf.WrongRank
550 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
551 error_result = False
552 error_reason = "Rank not supported for this operator"
553
554 if check:
555 input_shape = kwargs["input_shape"]
556
557 if (
558 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
559 and len(input_shape) != 4
560 ):
561 error_result = True
562 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
563 error_result = True
564 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
565 error_result = True
566 else:
567 if len(input_shape) not in rank_range:
568 error_result = True
569
570 info_dict = {
571 "error_name": error_name,
572 "error_result": error_result,
573 "error_reason": error_reason,
574 "param_reqs": param_reqs,
575 }
576 return info_dict
577
578 @staticmethod
579 def evWrongInputList(check=False, **kwargs):
580 error_name = ErrorIf.WrongInputList
581 param_reqs = {"rank": None, "dtype": None, "shape": None}
582 error_result = False
583 error_reason = "Op input list does not match expected input"
584
585 if check:
586 op = kwargs["op"]
587 input_list = kwargs["input_list"]
588 num_operands = kwargs["num_operands"]
589 if op["op"] in [Op.SCATTER, Op.GATHER]:
590 # SCATTER/GATHER add an indices input tensor in their build functions
591 num_operands += 1
592 if len(input_list) != num_operands:
593 error_result = True
594
595 info_dict = {
596 "error_name": error_name,
597 "error_result": error_result,
598 "error_reason": error_reason,
599 "param_reqs": param_reqs,
600 }
601 return info_dict
602
603 @staticmethod
604 def evWrongOutputList(check=False, **kwargs):
605 error_name = ErrorIf.WrongOutputList
606 param_reqs = {"rank": None, "dtype": None, "shape": None}
607 error_result = False
608 error_reason = "Op output list does not match expected output"
609
610 if check:
611 output_list = kwargs["output_list"]
612 # Note this will be incorrect if an operator returns more than one output
613 if len(output_list) != 1:
614 error_result = True
615
616 info_dict = {
617 "error_name": error_name,
618 "error_result": error_result,
619 "error_reason": error_reason,
620 "param_reqs": param_reqs,
621 }
622 return info_dict
623
624 @staticmethod
625 def evMaxDimExceeded(check=False, **kwargs):
626 error_name = ErrorIf.MaxDimExceeded
627 param_reqs = {
628 "rank": [4, 4],
629 "dtype": [DType.INT8],
630 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
631 }
632 error_result = False
633 error_reason = (
634 "At least one maximum dimension is greater than or equal to 16384"
635 )
636
637 if check:
638 input_shape = kwargs["input_shape"]
639 output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
640 if (
641 (input_shape[1] >= 16384)
642 or (input_shape[2] >= 16384)
643 or (output_shape[0] >= 16384)
644 or (output_shape[1] >= 16384)
645 ):
646 error_result = True
647
648 info_dict = {
649 "error_name": error_name,
650 "error_result": error_result,
651 "error_reason": error_reason,
652 "param_reqs": param_reqs,
653 }
654 return info_dict
655
656 @staticmethod
657 def evBatchMismatch(check=False, **kwargs):
658 error_name = ErrorIf.BatchMismatch
659 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
660 error_result = False
661 error_reason = "Input batch size not equal to output batch size"
662
663 assert "op" in kwargs
664 op = kwargs["op"]
665 rmin, rmax = op["rank"]
666 rank_range = range(rmin, rmax + 1)
667
668 if check:
669 input_shape = kwargs["input_shape"]
670 output_shape = kwargs[
671 "result_tensor"
672 ].shape # Note this is just (N, OH, OW, C)
673
674 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
675 error_result = True
676
677 info_dict = {
678 "error_name": error_name,
679 "error_result": error_result,
680 "error_reason": error_reason,
681 "param_reqs": param_reqs,
682 }
683 return info_dict
684
685 @staticmethod
686 def evChannelMismatch(check=False, **kwargs):
687 error_name = ErrorIf.ChannelMismatch
688 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
689 error_result = False
690 error_reason = "Input channel size not equal to output channel size"
691
692 assert "op" in kwargs
693 op = kwargs["op"]
694 rmin, rmax = op["rank"]
695 rank_range = range(rmin, rmax + 1)
696
697 if check:
698 input_shape = kwargs["input_shape"]
699 output_shape = kwargs[
700 "result_tensor"
701 ].shape # Note this is just (N, OH, OW, C)
702 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
703 error_result = True
704
705 info_dict = {
706 "error_name": error_name,
707 "error_result": error_result,
708 "error_reason": error_reason,
709 "param_reqs": param_reqs,
710 }
711 return info_dict
712
713 @staticmethod
714 def evStrideSmallerEqualZero(check=False, **kwargs):
715 error_name = ErrorIf.StrideSmallerEqualZero
716 param_reqs = {"rank": None, "dtype": None, "shape": None}
717 error_result = False
718 error_reason = "Stride value smaller than or equal zero"
719
720 if check:
721 input_dtype = kwargs["input_dtype"]
722 output_dtype = kwargs["output_dtype"]
723 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
724 stride = kwargs["stride"] # Work around wrong input/output type tests
725 elif output_dtype == DType.FLOAT:
726 stride = kwargs["stride_fp"]
727 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
728 stride = kwargs[
729 "stride_fp"
730 ] # Work around wrong input/output type tests
731 else:
732 stride = kwargs["stride"]
733
734 if min(stride) <= 0:
735 error_result = True
736
737 info_dict = {
738 "error_name": error_name,
739 "error_result": error_result,
740 "error_reason": error_reason,
741 "param_reqs": param_reqs,
742 }
743 return info_dict
744
745 @staticmethod
746 def evStrideLargerEqualMax(check=False, **kwargs):
747 error_name = ErrorIf.StrideLargerEqualMax
748 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
749 error_result = False
750 error_reason = "Stride value larger than or equal to maximum value"
751
752 if check:
753 shift = kwargs["shift"]
754 input_dtype = kwargs["input_dtype"]
755 stride = kwargs["stride"]
756 if input_dtype in [DType.INT8, DType.INT16]:
757 if shift >= 0 and (
758 stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
759 ):
760 error_result = True
761 elif shift < 0 and (
762 stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
763 ):
764 error_result = True
765
766 info_dict = {
767 "error_name": error_name,
768 "error_result": error_result,
769 "error_reason": error_reason,
770 "param_reqs": param_reqs,
771 }
772 return info_dict
773
774 @staticmethod
775 def evStrideLargerDimension(check=False, **kwargs):
776 error_name = ErrorIf.StrideLargerDimension
777 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
778 error_result = False
779 error_reason = "Stride value larger than or equal to H/W dimension"
780
781 if check:
782 shape = kwargs["input_shape"]
783 input_dtype = kwargs["input_dtype"]
784 stride = kwargs["stride_fp"]
785
786 if (
787 input_dtype == DType.FLOAT
788 and (stride[0] > shape[1])
789 or (stride[1] > shape[2])
790 ):
791 error_result = True
792
793 info_dict = {
794 "error_name": error_name,
795 "error_result": error_result,
796 "error_reason": error_reason,
797 "param_reqs": param_reqs,
798 }
799 return info_dict
800
801 @staticmethod
802 def evOffsetSmallerEqualMin(check=False, **kwargs):
803 error_name = ErrorIf.OffsetSmallerEqualMin
804 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
805 error_result = False
806 error_reason = "Offset value smaller than or equal to minimum value"
807
808 if check:
809 shift = kwargs["shift"]
810 output_dtype = kwargs["output_dtype"]
811 if output_dtype == DType.FLOAT:
812 offset = kwargs["offset_fp"]
813 else:
814 offset = kwargs["offset"]
815
816 if shift >= 0 and (
817 offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
818 ):
819 error_result = True
820 elif shift < 0 and (
821 offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
822 ):
823 error_result = True
824
825 info_dict = {
826 "error_name": error_name,
827 "error_result": error_result,
828 "error_reason": error_reason,
829 "param_reqs": param_reqs,
830 }
831 return info_dict
832
833 @staticmethod
834 def evOffsetLargerEqualMax(check=False, **kwargs):
835 error_name = ErrorIf.OffsetLargerEqualMax
836 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
837 error_result = False
838 error_reason = "Offset value larger than or equal to maximum value"
839
840 if check:
841 shift = kwargs["shift"]
842 output_dtype = kwargs["output_dtype"]
843 if output_dtype == DType.FLOAT:
844 offset = kwargs["offset_fp"]
845 else:
846 offset = kwargs["offset"]
847
848 if shift >= 0:
849 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
850 error_result = True
851
852 if shift >= 0 and (
853 offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
854 ):
855 error_result = True
856 elif shift < 0 and (
857 offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
858 ):
859 error_result = True
860
861 info_dict = {
862 "error_name": error_name,
863 "error_result": error_result,
864 "error_reason": error_reason,
865 "param_reqs": param_reqs,
866 }
867 return info_dict
868
869 @staticmethod
870 def evShiftNotZero(check=False, **kwargs):
871 error_name = ErrorIf.ShiftNotZero
872 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
873 error_result = False
874 error_reason = "Shift value must be zero for float input"
875
876 if check:
877 shift = kwargs["shift"]
878 input_dtype = kwargs["input_dtype"]
879 output_dtype = kwargs["output_dtype"]
880 if (
881 input_dtype == DType.FLOAT
882 and output_dtype == DType.FLOAT
883 and shift != 0
884 ):
885 error_result = True
886
887 info_dict = {
888 "error_name": error_name,
889 "error_result": error_result,
890 "error_reason": error_reason,
891 "param_reqs": param_reqs,
892 }
893 return info_dict
894
895 @staticmethod
896 def evShiftSmallerOne(check=False, **kwargs):
897 error_name = ErrorIf.ShiftSmallerOne
898 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
899 error_result = False
900 error_reason = "Shift value smaller than one"
901
902 if check:
903 shift = kwargs["shift"]
904 input_dtype = kwargs["input_dtype"]
905 output_dtype = kwargs["output_dtype"]
906 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
907 error_result = True
908
909 info_dict = {
910 "error_name": error_name,
911 "error_result": error_result,
912 "error_reason": error_reason,
913 "param_reqs": param_reqs,
914 }
915 return info_dict
916
917 @staticmethod
918 def evShiftLargerEleven(check=False, **kwargs):
919 error_name = ErrorIf.ShiftLargerEleven
920 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
921 error_result = False
922 error_reason = "Shift value larger than eleven"
923
924 if check:
925 shift = kwargs["shift"]
926 if shift > 11:
927 error_result = True
928
929 info_dict = {
930 "error_name": error_name,
931 "error_result": error_result,
932 "error_reason": error_reason,
933 "param_reqs": param_reqs,
934 }
935 return info_dict
936
937 @staticmethod
938 def evRankMismatch(check=False, **kwargs):
939 error_name = ErrorIf.RankMismatch
940 param_reqs = {"rank": None, "dtype": None, "shape": None}
941 error_result = False
942 error_reason = "Input Rank does not match output rank"
943
944 if check:
945 input1_shape = kwargs["input1"].shape
946 input2_shape = kwargs["input2"].shape
947 # In case of SELECT op
948 input3_shape = (
949 kwargs["input3"].shape if "input3" in kwargs else input2_shape
950 )
951 output_shape = kwargs["result_tensor"].shape
952 if (
953 (len(input1_shape) != len(output_shape))
954 or (len(input2_shape) != len(output_shape))
955 or (len(input3_shape) != len(output_shape))
956 ):
957 error_result = True
958
959 info_dict = {
960 "error_name": error_name,
961 "error_result": error_result,
962 "error_reason": error_reason,
963 "param_reqs": param_reqs,
964 }
965 return info_dict
966
967 @staticmethod
968 def evDimensionMismatch(check=False, **kwargs):
969 error_name = ErrorIf.DimensionMismatch
970 param_reqs = {"rank": None, "dtype": None, "shape": None}
971 error_result = False
972 error_reason = "Input Dimensions do not match output"
973
974 if check:
975 input1_shape = kwargs["input1"].shape
976 input2_shape = kwargs["input2"].shape
977 # In case of SELECT op
978 input3_shape = (
979 kwargs["input3"].shape if "input3" in kwargs else input2_shape
980 )
981 output_shape = kwargs["result_tensor"].shape
982 for i in range(
983 min(len(input1_shape), len(input2_shape), len(input3_shape))
984 ):
985 if (
986 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
987 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
988 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
989 ):
990 error_result = True
991
992 info_dict = {
993 "error_name": error_name,
994 "error_result": error_result,
995 "error_reason": error_reason,
996 "param_reqs": param_reqs,
997 }
998 return info_dict
999
1000 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001001 def _getZeroPoint(qinfo, index):
1002 """Return zero point value from quantization info.
1003
1004 Generally input_zp is index 0, output_zp is index 1
1005 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001006 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001007
1008 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001009 def evInputZeroPointNotZero(check=False, **kwargs):
1010 op = kwargs["op"]
1011 error_result = False
1012
1013 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001014 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001015
1016 # This does not apply to quantizable types
1017 inputDtypes = [
1018 dtype
1019 for dtype in op["types"]
1020 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1021 or (not isinstance(dtype, list) and dtype not in qTypes)
1022 ]
1023
1024 if check:
1025 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001026 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001028 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001029 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001030 (kwargs["input_dtype"], input_zero_point),
1031 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001032 ):
1033 if dtype not in qTypes and zp != 0:
1034 error_result = True
1035 break
1036 else:
1037 error_result = input_dtype not in qTypes and input_zero_point != 0
1038
1039 info_dict = {
1040 "error_name": ErrorIf.InputZeroPointNotZero,
1041 "error_result": error_result,
1042 "error_reason": "Input DType not INT8 and zero point not 0",
1043 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1044 }
1045 return info_dict
1046
1047 @staticmethod
1048 def evWeightZeroPointNotZero(check=False, **kwargs):
1049 op = kwargs["op"]
1050
1051 # exclude inputs with INT8 weights
1052 inputDtypes = [
1053 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1054 ]
1055
1056 error_name = ErrorIf.WeightZeroPointNotZero
1057 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1058 error_result = False
1059 error_reason = "Weight DType not INT8 and zero point not 0"
1060
1061 if check:
1062 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001063 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001064 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1065 error_result = True
1066
1067 info_dict = {
1068 "error_name": error_name,
1069 "error_result": error_result,
1070 "error_reason": error_reason,
1071 "param_reqs": param_reqs,
1072 }
1073 return info_dict
1074
1075 @staticmethod
1076 def evOutputZeroPointNotZero(check=False, **kwargs):
1077 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001078 inputDtypes = [
1079 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1080 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001081
1082 error_name = ErrorIf.OutputZeroPointNotZero
1083 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1084 error_result = False
1085 error_reason = "Output DType not INT8 and zero point not 0"
1086
1087 if check:
1088 input_dtype = kwargs["input_dtype"]
1089 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001090 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001091 if op["op"] == Op.AVG_POOL2D:
1092 if input_dtype != DType.INT8 and output_zero_point != 0:
1093 error_result = True
1094 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001095 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1096 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001097 ):
1098 error_result = True
1099
1100 info_dict = {
1101 "error_name": error_name,
1102 "error_result": error_result,
1103 "error_reason": error_reason,
1104 "param_reqs": param_reqs,
1105 }
1106 return info_dict
1107
1108 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001109 def evU16InputZeroPointNotValid(check=False, **kwargs):
1110 error_name = ErrorIf.U16InputZeroPointNotValid
1111 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1112 error_result = False
1113 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1114
1115 if check:
1116 input_dtype = kwargs["input_dtype"]
1117 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1118 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1119 0,
1120 32768,
1121 ]
1122
1123 info_dict = {
1124 "error_name": error_name,
1125 "error_result": error_result,
1126 "error_reason": error_reason,
1127 "param_reqs": param_reqs,
1128 }
1129 return info_dict
1130
1131 @staticmethod
1132 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1133 error_name = ErrorIf.U16OutputZeroPointNotValid
1134 param_reqs = {"rank": None, "dtype": None, "shape": None}
1135 error_result = False
1136 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1137
1138 if check:
1139 output_dtype = kwargs["output_dtype"]
1140 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1141
1142 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1143 0,
1144 32768,
1145 ]
1146
1147 info_dict = {
1148 "error_name": error_name,
1149 "error_result": error_result,
1150 "error_reason": error_reason,
1151 "param_reqs": param_reqs,
1152 }
1153 return info_dict
1154
1155 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001156 def evAxisSmallerZero(check=False, **kwargs):
1157 error_name = ErrorIf.AxisSmallerZero
1158 param_reqs = {"rank": None, "dtype": None, "shape": None}
1159 error_result = False
1160 error_reason = "Axis smaller than zero"
1161
1162 if check:
1163 axis = kwargs["axis"]
1164 if axis < 0:
1165 error_result = True
1166
1167 info_dict = {
1168 "error_name": error_name,
1169 "error_result": error_result,
1170 "error_reason": error_reason,
1171 "param_reqs": param_reqs,
1172 }
1173 return info_dict
1174
1175 @staticmethod
1176 def evAxisLargerRank(check=False, **kwargs):
1177 error_name = ErrorIf.AxisLargerRank
1178 param_reqs = {"rank": None, "dtype": None, "shape": None}
1179 error_result = False
1180 error_reason = "Axis larger than rank"
1181
1182 if check:
1183 axis = kwargs["axis"]
1184 shape = kwargs["input_shape"]
1185 if axis > len(shape):
1186 error_result = True
1187
1188 info_dict = {
1189 "error_name": error_name,
1190 "error_result": error_result,
1191 "error_reason": error_reason,
1192 "param_reqs": param_reqs,
1193 }
1194 return info_dict
1195
1196 @staticmethod
1197 def evShapeOfAxisNotOne(check=False, **kwargs):
1198 error_name = ErrorIf.ShapeOfAxisNotOne
1199 param_reqs = {"rank": None, "dtype": None, "shape": None}
1200 error_result = False
1201 error_reason = "shape[axis] is not equal to 1"
1202
1203 if check:
1204 axis = kwargs["axis"]
1205 shape = kwargs["output_shape"]
1206 if (0 <= axis < len(shape)) and shape[axis] != 1:
1207 error_result = True
1208
1209 info_dict = {
1210 "error_name": error_name,
1211 "error_result": error_result,
1212 "error_reason": error_reason,
1213 "param_reqs": param_reqs,
1214 }
1215 return info_dict
1216
1217 @staticmethod
1218 def evPadSmallerZero(check=False, **kwargs):
1219 error_name = ErrorIf.PadSmallerZero
1220 param_reqs = {"rank": None, "dtype": None, "shape": None}
1221 error_result = False
1222 error_reason = "At least one pad is smaller than zero"
1223
1224 if check:
1225 op = kwargs["op"]
1226 pad = kwargs["pad"]
1227 if op["op"] == Op.PAD:
1228 for padding in pad:
1229 if min(padding) < 0:
1230 error_result = True
1231 else:
1232 if min(pad) < 0:
1233 error_result = True
1234
1235 info_dict = {
1236 "error_name": error_name,
1237 "error_result": error_result,
1238 "error_reason": error_reason,
1239 "param_reqs": param_reqs,
1240 }
1241 return info_dict
1242
1243 @staticmethod
1244 def evPadLargerEqualKernel(check=False, **kwargs):
1245 error_name = ErrorIf.PadLargerEqualKernel
1246 param_reqs = {"rank": None, "dtype": None, "shape": None}
1247 error_result = False
1248 error_reason = "At least one pad is larger than kernel dimension"
1249
1250 if check:
1251 pad = kwargs["pad"]
1252 kernel = kwargs["kernel"]
1253 if min(pad) > 0 and min(kernel) > 1:
1254 if (
1255 pad[0] >= kernel[0]
1256 or pad[1] >= kernel[0]
1257 or pad[2] >= kernel[1]
1258 or pad[3] >= kernel[1]
1259 ):
1260 error_result = True
1261
1262 info_dict = {
1263 "error_name": error_name,
1264 "error_result": error_result,
1265 "error_reason": error_reason,
1266 "param_reqs": param_reqs,
1267 }
1268 return info_dict
1269
1270 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001271 def checkPoolingParams(kernel, stride, pad):
1272 return (
1273 min(kernel) >= 1
1274 and min(stride) >= 1
1275 and min(pad) >= 0
1276 and not (
1277 pad[0] >= kernel[0]
1278 or pad[1] >= kernel[0]
1279 or pad[2] >= kernel[1]
1280 or pad[3] >= kernel[1]
1281 )
1282 )
1283
1284 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1286 error_name = ErrorIf.PoolingOutputShapeMismatch
1287 param_reqs = {"rank": None, "dtype": None, "shape": None}
1288 error_result = False
1289 error_reason = (
1290 "Mismatch between output shape provided and expected output shape"
1291 )
1292
1293 if check:
1294 pad = kwargs["pad"]
1295 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1296
1297 kernel = kwargs["kernel"]
1298 kernel_y, kernel_x = kernel[0], kernel[1]
1299
1300 input_shape = kwargs["input_shape"]
1301 IH, IW = input_shape[1], input_shape[2]
1302
1303 output_shape = kwargs["output_shape"]
1304 OH, OW = output_shape[1], output_shape[2]
1305
1306 stride = kwargs["stride"]
1307 stride_y, stride_x = stride[0], stride[1]
1308
1309 # calculate correct height, width dimensions
1310 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001311 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1312 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001313
1314 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001315 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001316
1317 if params_valid and (OH != y_correct or OW != x_correct):
1318 error_result = True
1319
1320 info_dict = {
1321 "error_name": error_name,
1322 "error_result": error_result,
1323 "error_reason": error_reason,
1324 "param_reqs": param_reqs,
1325 }
1326 return info_dict
1327
1328 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001329 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1330 error_name = ErrorIf.PoolingOutputShapeNonInteger
1331 param_reqs = {"rank": None, "dtype": None, "shape": None}
1332 error_result = False
1333 error_reason = "Parameters do not yield exact integer output dimensions"
1334
1335 if check:
1336 pad = kwargs["pad"]
1337 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1338
1339 kernel = kwargs["kernel"]
1340 kernel_y, kernel_x = kernel[0], kernel[1]
1341
1342 input_shape = kwargs["input_shape"]
1343 IH, IW = input_shape[1], input_shape[2]
1344
1345 stride = kwargs["stride"]
1346 stride_y, stride_x = stride[0], stride[1]
1347
1348 # calculate remainder of height, width dimensions
1349 if stride_x != 0 and stride_y != 0:
1350 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1351 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1352
1353 # ensure parameters are valid
1354 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1355 if params_valid and (y_remainder != 0 or x_remainder != 0):
1356 error_result = True
1357
1358 info_dict = {
1359 "error_name": error_name,
1360 "error_result": error_result,
1361 "error_reason": error_reason,
1362 "param_reqs": param_reqs,
1363 }
1364 return info_dict
1365
1366 @staticmethod
1367 def checkConvParams(weight_shape, stride, pad, dilation):
1368 return (
1369 # Check kernel sizes
1370 min(weight_shape[1:-1]) >= 1
1371 and min(stride) >= 1
1372 and min(pad) >= 0
1373 and (dilation is None or min(dilation) >= 1)
1374 )
1375
1376 @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """Validate the ConvOutputShapeMismatch ERROR_IF condition.

        Recomputes the expected spatial output dimensions for a conv-like
        op and flags a mismatch against ``output_shape[1:-1]`` (assumes a
        channels-last layout where the spatial dims sit between batch and
        channel).  Skipped when the conv parameters are themselves invalid
        so only one error is reported at a time.
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # TRANSPOSE_CONV2D has no dilation attribute.
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights are KH,KW-first; other convs have the
            # output-channel dim first, so the kernel dims start at index 1.
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad holds two entries (before, after) per spatial dim.
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        # Inverse of the forward-conv formula; note the pads
                        # are subtracted here (out_pad convention) - see the
                        # TOSA spec for TRANSPOSE_CONV2D output sizing.
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            - pad[pad_offset]
                            - pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        # Standard dilated-conv output size:
                        # (in - 1 + pads - (kernel - 1) * dilation) // stride + 1
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1437
1438 @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        """Validate the ConvOutputShapeNonInteger ERROR_IF condition.

        Flags parameter sets where the dilated-conv output size formula
        leaves a non-zero remainder, i.e. the spatial output dimensions
        would not be integral.  Skipped when the conv parameters are
        themselves invalid so only one error is reported at a time.
        """
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights are KH,KW-first; other convs have the
            # output-channel dim first, so the kernel dims start at index 1.
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad holds two entries (before, after) per spatial dim.
                    pad_offset = index * 2
                    # Remainder of the standard output-size formula:
                    # (in - 1 + pads - (kernel - 1) * dilation) % stride
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            # (checkConvParams requires min(stride) >= 1, so short-circuiting
            # below also protects max() from an empty `remainders` list)
            params_valid = TosaErrorValidator.checkConvParams(
                weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1486
1487 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001488 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1489 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1490 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1491 error_result = False
1492 error_reason = (
1493 "Mismatch between output shape provided and expected output shape"
1494 )
1495
1496 if check:
1497 output_shape = kwargs["output_shape"]
1498 input_shape = kwargs["input_shape"]
1499 axis = kwargs["axis"]
1500
1501 dimension_match = True
1502 axis_shift = 0
1503
1504 # Check that rank is correct before trying to check dimensions
1505 if (len(input_shape) - 1) == len(output_shape):
1506 for i in range(len(input_shape)):
1507 if i == axis:
1508 axis_shift = 1
1509 continue
1510 if input_shape[i] != output_shape[i - axis_shift]:
1511 dimension_match = False
1512
1513 if not dimension_match:
1514 error_result = True
1515
1516 info_dict = {
1517 "error_name": error_name,
1518 "error_result": error_result,
1519 "error_reason": error_reason,
1520 "param_reqs": param_reqs,
1521 }
1522 return info_dict
1523
1524 @staticmethod
1525 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1526 error_name = ErrorIf.ArgmaxOutputRankMismatch
1527 param_reqs = {"rank": None, "dtype": None, "shape": None}
1528 error_result = False
1529 error_reason = (
1530 "Mismatch between output shape provided and expected output shape"
1531 )
1532
1533 if check:
1534 output_shape = kwargs["output_shape"]
1535 input_shape = kwargs["input_shape"]
1536 axis = kwargs["axis"]
1537 valid_params = axis >= 0 and axis < len(input_shape)
1538
1539 if valid_params and (len(input_shape) - 1) != len(output_shape):
1540 error_result = True
1541
1542 info_dict = {
1543 "error_name": error_name,
1544 "error_result": error_result,
1545 "error_reason": error_reason,
1546 "param_reqs": param_reqs,
1547 }
1548 return info_dict
1549
1550 @staticmethod
1551 def evKernelSmallerOne(check=False, **kwargs):
1552 error_name = ErrorIf.KernelSmallerOne
1553 param_reqs = {"rank": None, "dtype": None, "shape": None}
1554 error_result = False
1555 error_reason = "At least one kernel dimension is smaller than zero"
1556
1557 if check:
1558 kernel = kwargs["kernel"]
1559 if min(kernel) < 1:
1560 error_result = True
1561
1562 info_dict = {
1563 "error_name": error_name,
1564 "error_result": error_result,
1565 "error_reason": error_reason,
1566 "param_reqs": param_reqs,
1567 }
1568 return info_dict
1569
1570 @staticmethod
1571 def evStrideSmallerOne(check=False, **kwargs):
1572 error_name = ErrorIf.StrideSmallerOne
1573 param_reqs = {"rank": None, "dtype": None, "shape": None}
1574 error_result = False
1575 error_reason = "At least one stride dimension is smaller than zero"
1576
1577 if check:
1578 stride = kwargs["stride"]
1579 if min(stride) < 1:
1580 error_result = True
1581
1582 info_dict = {
1583 "error_name": error_name,
1584 "error_result": error_result,
1585 "error_reason": error_reason,
1586 "param_reqs": param_reqs,
1587 }
1588 return info_dict
1589
1590 @staticmethod
1591 def evDilationSmallerOne(check=False, **kwargs):
1592 error_result = check and min(kwargs["dilation"]) < 1
1593 return {
1594 "error_name": ErrorIf.DilationSmallerOne,
1595 "error_reason": "At least one dilation is smaller than one",
1596 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1597 "error_result": error_result,
1598 }
1599
1600 @staticmethod
1601 def evScaleTrue(check=False, **kwargs):
1602 error_name = ErrorIf.ScaleTrue
1603 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1604 error_result = False
1605 error_reason = "Scale set to true but input type is INT48"
1606
1607 if check:
1608 input_dtype = kwargs["input_dtype"]
1609 scale32 = kwargs["scale32"]
1610 if scale32 and input_dtype == DType.INT48:
1611 error_result = True
1612
1613 info_dict = {
1614 "error_name": error_name,
1615 "error_result": error_result,
1616 "error_reason": error_reason,
1617 "param_reqs": param_reqs,
1618 }
1619 return info_dict
1620
1621 @staticmethod
1622 def evScaleNotTrue(check=False, **kwargs):
1623 error_name = ErrorIf.ScaleNotTrue
1624 param_reqs = {"rank": None, "dtype": None, "shape": None}
1625 error_result = False
1626 error_reason = "Scale set to false but double round set to true"
1627
1628 if check:
1629 scale32 = kwargs["scale32"]
1630 double_round = kwargs["double_round"]
1631 if not scale32 and double_round:
1632 error_result = True
1633
1634 info_dict = {
1635 "error_name": error_name,
1636 "error_result": error_result,
1637 "error_reason": error_reason,
1638 "param_reqs": param_reqs,
1639 }
1640 return info_dict
1641
1642 @staticmethod
1643 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1644 error_name = ErrorIf.TensorSizeInputOutputMismatch
1645 param_reqs = {"rank": None, "dtype": None, "shape": None}
1646 error_result = False
1647 error_reason = "Input tensor size does not match output tensor size"
1648
1649 if check:
1650 input_shape = kwargs["input_shape"]
1651 output_shape = kwargs["output_shape"]
1652 input_size = np.prod(input_shape)
1653 output_size = np.prod(output_shape)
1654 if input_size != output_size:
1655 error_result = True
1656
1657 info_dict = {
1658 "error_name": error_name,
1659 "error_result": error_result,
1660 "error_reason": error_reason,
1661 "param_reqs": param_reqs,
1662 }
1663 return info_dict
1664
1665 @staticmethod
1666 def evStartSmallerZero(check=False, **kwargs):
1667 error_name = ErrorIf.StartSmallerZero
1668 param_reqs = {"rank": None, "dtype": None, "shape": None}
1669 error_result = False
1670 error_reason = "Starting point smaller than zero"
1671
1672 if check:
1673 input_shape = kwargs["input_shape"]
1674 start = kwargs["start"]
1675 rank = len(input_shape)
1676 if len(start) == rank:
1677 for index in range(rank):
1678 if start[index] < 0:
1679 error_result = True
1680
1681 info_dict = {
1682 "error_name": error_name,
1683 "error_result": error_result,
1684 "error_reason": error_reason,
1685 "param_reqs": param_reqs,
1686 }
1687 return info_dict
1688
1689 @staticmethod
1690 def evSizeSmallerEqualZero(check=False, **kwargs):
1691 error_name = ErrorIf.SizeSmallerEqualZero
1692 param_reqs = {"rank": None, "dtype": None, "shape": None}
1693 error_result = False
1694 error_reason = "Size smaller than or equal to zero"
1695
1696 if check:
1697 input_shape = kwargs["input_shape"]
1698 size = kwargs["size"]
1699 rank = len(input_shape)
1700 if len(size) == rank:
1701 for index in range(rank):
1702 if size[index] <= 0:
1703 error_result = True
1704
1705 info_dict = {
1706 "error_name": error_name,
1707 "error_result": error_result,
1708 "error_reason": error_reason,
1709 "param_reqs": param_reqs,
1710 }
1711 return info_dict
1712
1713 @staticmethod
1714 def evStartSizeOutsideBounds(check=False, **kwargs):
1715 error_name = ErrorIf.StartSizeOutsideBounds
1716 param_reqs = {"rank": None, "dtype": None, "shape": None}
1717 error_result = False
1718 error_reason = "starting point plus size larger than input dimension"
1719
1720 if check:
1721 input_shape = kwargs["input_shape"]
1722 start = kwargs["start"]
1723 size = kwargs["size"]
1724 rank = len(input_shape)
1725 if len(start) == rank and len(size) == rank:
1726 for index in range(rank):
1727 if start[index] + size[index] > input_shape[index]:
1728 error_result = True
1729
1730 info_dict = {
1731 "error_name": error_name,
1732 "error_result": error_result,
1733 "error_reason": error_reason,
1734 "param_reqs": param_reqs,
1735 }
1736 return info_dict
1737
1738 @staticmethod
1739 def evSizeOutputShapeMismatch(check=False, **kwargs):
1740 error_name = ErrorIf.SizeOutputShapeMismatch
1741 param_reqs = {"rank": None, "dtype": None, "shape": None}
1742 error_result = False
1743 error_reason = "Size does not match output dimension"
1744
1745 if check:
1746 input_shape = kwargs["input_shape"]
1747 output_shape = kwargs["output_shape"]
1748 size = kwargs["size"]
1749 rank = len(input_shape)
1750 if len(size) == rank:
1751 for index in range(rank):
1752 if size[index] != output_shape[index]:
1753 error_result = True
1754
1755 info_dict = {
1756 "error_name": error_name,
1757 "error_result": error_result,
1758 "error_reason": error_reason,
1759 "param_reqs": param_reqs,
1760 }
1761 return info_dict
1762
1763 @staticmethod
1764 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1765 error_name = ErrorIf.InputSizeStartLengthMismatch
1766 param_reqs = {"rank": None, "dtype": None, "shape": None}
1767 error_result = False
1768 error_reason = "rank of input not equal to length of start or size"
1769
1770 if check:
1771 input_shape = kwargs["input_shape"]
1772 start = kwargs["start"]
1773 size = kwargs["size"]
1774 rank = len(input_shape)
1775 if rank != len(start) or rank != len(size):
1776 error_result = True
1777
1778 info_dict = {
1779 "error_name": error_name,
1780 "error_result": error_result,
1781 "error_reason": error_reason,
1782 "param_reqs": param_reqs,
1783 }
1784 return info_dict
1785
1786 @staticmethod
1787 def evIndexOutsideBounds(check=False, **kwargs):
1788 error_name = ErrorIf.IndexOutsideBounds
1789 param_reqs = {"rank": None, "dtype": None, "shape": None}
1790 error_result = False
1791 error_reason = "Index outside of allowed bounds"
1792
1793 if check:
1794 input_shape = kwargs["input_shape"]
1795 perms = kwargs["perms"]
1796 rank = len(input_shape)
1797
1798 for index in perms:
1799 if index < 0 or index > rank:
1800 error_result = True
1801
1802 info_dict = {
1803 "error_name": error_name,
1804 "error_result": error_result,
1805 "error_reason": error_reason,
1806 "param_reqs": param_reqs,
1807 }
1808 return info_dict
1809
1810 @staticmethod
1811 def evIndexUsedTwice(check=False, **kwargs):
1812 error_name = ErrorIf.IndexUsedTwice
1813 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1814 error_result = False
1815 error_reason = "Index used multiple times"
1816
1817 if check:
1818 perms = kwargs["perms"]
1819
1820 unique_indices = []
1821 for index in perms:
1822 if index in unique_indices:
1823 error_result = True
1824 else:
1825 unique_indices.append(index)
1826
1827 info_dict = {
1828 "error_name": error_name,
1829 "error_result": error_result,
1830 "error_reason": error_reason,
1831 "param_reqs": param_reqs,
1832 }
1833 return info_dict
1834
1835 @staticmethod
1836 def evMaxSmallerMin(check=False, **kwargs):
1837 error_name = ErrorIf.MaxSmallerMin
1838 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1839 error_result = False
1840 error_reason = "Max value smaller than min value"
1841
1842 if check:
1843 max_val = kwargs["max_val"]
1844 min_val = kwargs["min_val"]
1845 if max_val < min_val:
1846 error_result = True
1847
1848 info_dict = {
1849 "error_name": error_name,
1850 "error_result": error_result,
1851 "error_reason": error_reason,
1852 "param_reqs": param_reqs,
1853 }
1854 return info_dict
1855
1856 @staticmethod
1857 def evConcatInputRankMismatch(check=False, **kwargs):
1858 error_name = ErrorIf.ConcatInputRankMismatch
1859 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1860 error_result = False
1861 error_reason = "Input ranks are not identical"
1862
1863 if check:
1864 inputs = kwargs["inputs"]
1865 input_shape = kwargs["input_shape"]
1866 for input in inputs:
1867 if len(input.shape) != len(input_shape):
1868 error_result = True
1869
1870 info_dict = {
1871 "error_name": error_name,
1872 "error_result": error_result,
1873 "error_reason": error_reason,
1874 "param_reqs": param_reqs,
1875 }
1876 return info_dict
1877
1878 @staticmethod
1879 def evConcatInputDimMismatch(check=False, **kwargs):
1880 error_name = ErrorIf.ConcatInputDimMismatch
1881 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1882 error_result = False
1883 error_reason = "Input dimensions differ on too many axes"
1884
1885 if check:
1886 inputs = kwargs["inputs"]
1887 input_shape = kwargs["input_shape"]
1888 axis = kwargs["axis"]
1889
1890 # Ensure rank is valid before checking dims.
1891 valid_rank = True
1892 for input in inputs:
1893 if len(input.shape) != len(input_shape):
1894 valid_rank = False
1895
1896 if valid_rank:
1897 for input in inputs:
1898 for i, dim in enumerate(input.shape):
1899 if dim != input_shape[i] and axis != i:
1900 error_result = True
1901
1902 info_dict = {
1903 "error_name": error_name,
1904 "error_result": error_result,
1905 "error_reason": error_reason,
1906 "param_reqs": param_reqs,
1907 }
1908 return info_dict
1909
1910 @staticmethod
1911 def evConcatShapeSumMismatch(check=False, **kwargs):
1912 error_name = ErrorIf.ConcatShapeSumMismatch
1913 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1914 error_result = False
1915 error_reason = "Sum of dimensions on axis not equal to output dimension"
1916
1917 if check:
1918 inputs = kwargs["inputs"]
1919 input_shape = kwargs["input_shape"]
1920 output_shape = kwargs["output_shape"]
1921 axis = kwargs["axis"]
1922
1923 # Ensure rank is valid before checking dims.
1924 valid_params = True
1925 for input in inputs:
1926 if len(input.shape) != len(input_shape):
1927 valid_params = False
1928 if axis < 0 or axis > len(input_shape):
1929 valid_params = False
1930
1931 if valid_params:
1932 axis_dim_sum = 0
1933 for input in inputs:
1934 axis_dim_sum += input.shape[axis]
1935
1936 if axis_dim_sum != output_shape[axis]:
1937 error_result = True
1938
1939 info_dict = {
1940 "error_name": error_name,
1941 "error_result": error_result,
1942 "error_reason": error_reason,
1943 "param_reqs": param_reqs,
1944 }
1945 return info_dict
1946
1947 @staticmethod
1948 def evInputListThenGraphMismatch(check=False, **kwargs):
1949 error_name = ErrorIf.CondIfInputListThenGraphMismatch
1950 param_reqs = {"rank": None, "dtype": None, "shape": None}
1951 error_result = False
1952 error_reason = "Input list shape does not match then-graph shape"
1953
1954 if check:
1955 a = kwargs["a"]
1956 b = kwargs["b"]
1957 basicBlocks = kwargs["basicBlocks"]
1958 then_block = basicBlocks[1]
1959 then_inputs = then_block.inputs
1960 then_tens = then_block.tensors
1961 if (a.shape != then_tens[then_inputs[0]].shape) or (
1962 b.shape != then_tens[then_inputs[1]].shape
1963 ):
1964 error_result = True
1965
1966 info_dict = {
1967 "error_name": error_name,
1968 "error_result": error_result,
1969 "error_reason": error_reason,
1970 "param_reqs": param_reqs,
1971 }
1972 return info_dict
1973
1974 @staticmethod
1975 def evInputListElseGraphMismatch(check=False, **kwargs):
1976 error_name = ErrorIf.CondIfInputListElseGraphMismatch
1977 param_reqs = {"rank": None, "dtype": None, "shape": None}
1978 error_result = False
1979 error_reason = "Input list shape does not match else-graph shape"
1980
1981 if check:
1982 a = kwargs["a"]
1983 b = kwargs["b"]
1984 basicBlocks = kwargs["basicBlocks"]
1985 else_block = basicBlocks[2]
1986 else_inputs = else_block.inputs
1987 else_tens = else_block.tensors
1988 if (a.shape != else_tens[else_inputs[0]].shape) or (
1989 b.shape != else_tens[else_inputs[1]].shape
1990 ):
1991 error_result = True
1992
1993 info_dict = {
1994 "error_name": error_name,
1995 "error_result": error_result,
1996 "error_reason": error_reason,
1997 "param_reqs": param_reqs,
1998 }
1999 return info_dict
2000
2001 @staticmethod
2002 def evOutputListThenGraphMismatch(check=False, **kwargs):
2003 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2004 param_reqs = {"rank": None, "dtype": None, "shape": None}
2005 error_result = False
2006 error_reason = "Output list shape does not match then-graph shape"
2007
2008 if check:
2009 basicBlocks = kwargs["basicBlocks"]
2010 cond_block = basicBlocks[0]
2011 cond_outputs = cond_block.outputs
2012 cond_tens = cond_block.tensors
2013 then_block = basicBlocks[1]
2014 then_outputs = then_block.outputs
2015 then_tens = then_block.tensors
2016 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2017 error_result = True
2018
2019 info_dict = {
2020 "error_name": error_name,
2021 "error_result": error_result,
2022 "error_reason": error_reason,
2023 "param_reqs": param_reqs,
2024 }
2025 return info_dict
2026
2027 @staticmethod
2028 def evOutputListElseGraphMismatch(check=False, **kwargs):
2029 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2030 param_reqs = {"rank": None, "dtype": None, "shape": None}
2031 error_result = False
2032 error_reason = "Output list shape does not match else-graph shape"
2033
2034 if check:
2035 basicBlocks = kwargs["basicBlocks"]
2036 cond_block = basicBlocks[0]
2037 cond_outputs = cond_block.outputs
2038 cond_tens = cond_block.tensors
2039 else_block = basicBlocks[2]
2040 else_outputs = else_block.outputs
2041 else_tens = else_block.tensors
2042 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2043 error_result = True
2044
2045 info_dict = {
2046 "error_name": error_name,
2047 "error_result": error_result,
2048 "error_reason": error_reason,
2049 "param_reqs": param_reqs,
2050 }
2051 return info_dict
2052
2053 @staticmethod
2054 def evInputListOutputListMismatch(check=False, **kwargs):
2055 error_name = ErrorIf.InputListOutputListMismatch
2056 param_reqs = {"rank": None, "dtype": None, "shape": None}
2057 error_result = False
2058 error_reason = "Input list does not match output list"
2059
2060 if check:
2061 basicBlocks = kwargs["basicBlocks"]
2062 while_block = basicBlocks[0]
2063 while_inputs = while_block.inputs
2064 while_outputs = while_block.outputs
2065 while_tens = while_block.tensors
2066 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2067 error_result = True
2068
2069 info_dict = {
2070 "error_name": error_name,
2071 "error_result": error_result,
2072 "error_reason": error_reason,
2073 "param_reqs": param_reqs,
2074 }
2075 return info_dict
2076
2077 @staticmethod
2078 def evInputListCondGraphMismatch(check=False, **kwargs):
2079 error_name = ErrorIf.InputListCondGraphMismatch
2080 param_reqs = {"rank": None, "dtype": None, "shape": None}
2081 error_result = False
2082 error_reason = "Input list does not match cond graph"
2083
2084 if check:
2085 basicBlocks = kwargs["basicBlocks"]
2086 while_block = basicBlocks[0]
2087 while_inputs = while_block.inputs
2088 while_tens = while_block.tensors
2089 cond_block = basicBlocks[1]
2090 cond_inputs = cond_block.inputs
2091 cond_tens = cond_block.tensors
2092 if (
2093 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2094 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2095 error_result = True
2096
2097 info_dict = {
2098 "error_name": error_name,
2099 "error_result": error_result,
2100 "error_reason": error_reason,
2101 "param_reqs": param_reqs,
2102 }
2103 return info_dict
2104
2105 @staticmethod
2106 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2107 error_name = ErrorIf.InputListBodyGraphInputMismatch
2108 param_reqs = {"rank": None, "dtype": None, "shape": None}
2109 error_result = False
2110 error_reason = "Input list does not match body graph input"
2111
2112 if check:
2113 basicBlocks = kwargs["basicBlocks"]
2114 while_block = basicBlocks[0]
2115 while_inputs = while_block.inputs
2116 while_tens = while_block.tensors
2117 body_block = basicBlocks[2]
2118 body_outputs = body_block.inputs
2119 body_tens = body_block.tensors
2120 if (
2121 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2122 ) or (
2123 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2124 ):
2125 error_result = True
2126
2127 info_dict = {
2128 "error_name": error_name,
2129 "error_result": error_result,
2130 "error_reason": error_reason,
2131 "param_reqs": param_reqs,
2132 }
2133 return info_dict
2134
2135 @staticmethod
2136 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2137 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2138 param_reqs = {"rank": None, "dtype": None, "shape": None}
2139 error_result = False
2140 error_reason = "Input list does not match body graph output"
2141
2142 if check:
2143 basicBlocks = kwargs["basicBlocks"]
2144 while_block = basicBlocks[0]
2145 while_inputs = while_block.inputs
2146 while_tens = while_block.tensors
2147 body_block = basicBlocks[2]
2148 body_outputs = body_block.outputs
2149 body_tens = body_block.tensors
2150 if (
2151 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2152 ) or (
2153 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2154 ):
2155 error_result = True
2156 info_dict = {
2157 "error_name": error_name,
2158 "error_result": error_result,
2159 "error_reason": error_reason,
2160 "param_reqs": param_reqs,
2161 }
2162 return info_dict
2163
2164 @staticmethod
2165 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2166 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2167 param_reqs = {"rank": None, "dtype": None, "shape": None}
2168 error_result = False
2169 error_reason = "Cond graph output is not a match list of booleans"
2170
2171 if check:
2172 basicBlocks = kwargs["basicBlocks"]
2173 cond_block = basicBlocks[1]
2174 cond_outputs = cond_block.outputs
2175 cond_tens = cond_block.tensors
2176 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2177 error_result = True
2178
2179 info_dict = {
2180 "error_name": error_name,
2181 "error_result": error_result,
2182 "error_reason": error_reason,
2183 "param_reqs": param_reqs,
2184 }
2185 return info_dict
2186
2187
class TosaInvalidValidator:
    """Predicates that decide whether a generated argument set is invalid.

    Each iv* static method inspects the keyword arguments supplied by the
    test generator and returns True when the combination is not legal for
    the operator, so the caller can filter the case out (or keep it as a
    negative test).
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode / dtype combination is invalid.

        kwargs:
            input_dtype: DType of the resize input tensor.
            args: operator argument list; args[0] is the resize mode and
                args[8] the output dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[8]

        if mode == ResizeMode.BILINEAR:
            # BUG FIX: the original OR-ed the *negation* of each valid
            # (input, output) pairing.  No single input dtype can satisfy
            # all pairings at once, so at least two negations were always
            # True and every BILINEAR case was reported invalid.  The
            # combination is valid only when exactly one pairing holds.
            return not (
                (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                or (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                or (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # NEAREST keeps the input dtype unchanged; only int8, int16 and
            # float inputs are defined for RESIZE.
            # NOTE(review): the original list contained DType.INT32 where
            # the TOSA spec allows INT16 - assumed typo, confirm against
            # the RESIZE operator specification.
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        else:
            # Any other mode value is itself invalid.
            return True

    @staticmethod
    def ivBadStride(**kwargs):
        """Return True if any resize stride is zero or negative.

        Integer strides live in args[1], floating-point strides in args[4];
        which pair applies depends on the input dtype.
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        stride_x = args[1][0]
        stride_y = args[1][1]
        stride_fp_x = args[4][0]
        stride_fp_y = args[4][1]

        if input_dtype == DType.FLOAT:
            if stride_fp_x <= 0 or stride_fp_y <= 0:
                # Negative or zero stride
                return True
        else:
            if stride_x <= 0 or stride_y <= 0:
                # Negative or zero stride
                return True
        return False

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True if the op's computed output height/width is invalid.

        Dispatches on opName:
          - *pool2d: output dim < 1 for the given kernel/stride/padding.
          - transpose_conv2d*: declared output shape matches none of the
            FULL / SAME / VALID padding-derived sizes.
          - *conv2d / *conv3d: any dilated-convolution output dim < 1.

        kwargs:
            opName: operator name string.
            shapeList: [input_shape, filter_shape, ...] (NHWC layout
                assumed - TODO confirm against the generator).
            args: operator argument list (layout differs per op, see body).
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]
        strides = args[0]
        padding = args[1]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d: args[2] holds the declared output shape
            output_shape = args[2]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # Accept the declared output shape if it matches any of the
            # three conventional padding schemes.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters are (KH, KW, ...), others (OC, KH, ..., IC)
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                # standard dilated-convolution output-size formula
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if the declared output H or W (args[2][1:3]) is <= 0."""
        args = kwargs["args"]
        output_shape = args[2]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False