blob: e7e758fbb608e4e71cd7dc1dd98daef5e1a8ef3e [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
4from generator.tosa_utils import product
5from generator.tosa_utils import usableDTypes
6from generator.tosa_utils import valueToName
7from tosa.DType import DType
8from tosa.Op import Op
9from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000010
Matthew Haddone86fd342021-09-07 16:12:21 +010011
class ErrorIf(object):
    """Canonical string names for the TOSA ERROR_IF conditions that the
    negative-test generator can deliberately provoke.

    Each attribute's value is its own name.  The strings are used both to
    select which ERROR_IF to trigger when building a test and to match the
    ``error_name`` field returned by the ``TosaErrorValidator`` functions.
    """

    MaxDimExceeded = "MaxDimExceeded"
    StrideSmallerEqualZero = "StrideSmallerEqualZero"
    StrideLargerEqualMax = "StrideLargerEqualMax"
    StrideLargerDimension = "StrideLargerDimension"
    OffsetSmallerEqualMin = "OffsetSmallerEqualMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    ShiftNotZero = "ShiftNotZero"
    ShiftSmallerOne = "ShiftSmallerOne"
    ShiftLargerEleven = "ShiftLargerEleven"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010071
72
73class TosaErrorIfArgGen:
74 @staticmethod
75 def eiResizeErrorIf(
76 testGen,
77 error_name,
78 mode,
79 dtype,
80 shapeList,
81 outputDType,
82 shift,
83 stride,
84 stride_fp,
85 offset,
86 offset_fp,
87 ):
88
89 if outputDType == DType.FLOAT:
90 if error_name == ErrorIf.StrideSmallerEqualZero:
91 stride_fp = testGen.rng.random(size=[2]) - 2
92 elif error_name == ErrorIf.ShiftNotZero:
93 shift = testGen.rng.integers(1, 5)
94 elif error_name == ErrorIf.StrideLargerDimension:
95 shape = shapeList[0]
96 transform_height = testGen.rng.choice([False, True])
97 if transform_height:
98 stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
99 else:
100 stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
101 else:
102 if error_name == ErrorIf.StrideSmallerEqualZero:
103 stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
104 elif error_name == ErrorIf.ShiftSmallerOne:
105 shift = testGen.rng.integers(-3, 1)
106 if shift <= 0:
107 stride = [
108 (16 >> -shift) - 1,
109 (16 >> -shift) - 1,
110 ] # avoids other ERROR_IF checks
111 offset = [
112 (16 >> -shift) - 1,
113 (16 >> -shift) - 1,
114 ] # avoids other ERROR_IF checks
115 else:
116 stride = [
117 (16 << shift) - 1,
118 (16 << shift) - 1,
119 ] # avoids other ERROR_IF checks
120 offset = [
121 (16 << shift) - 1,
122 (16 << shift) - 1,
123 ] # avoids other ERROR_IF checks
124 elif error_name == ErrorIf.ShiftLargerEleven:
125 shift = np.int16(testGen.rng.integers(12, 15))
126 elif error_name == ErrorIf.StrideLargerDimension:
127 shape = shapeList[0]
128 transform_height = testGen.rng.choice([False, True])
129 if transform_height:
130 stride[0] = shape[1] + testGen.rng.integers(1, 10)
131 else:
132 stride[1] = shape[2] + testGen.rng.integers(1, 10)
133 elif error_name == ErrorIf.StrideLargerEqualMax:
134 stride = [(16 << shift) + 1, (16 << shift) + 1]
135 elif error_name == ErrorIf.OffsetLargerEqualMax:
136 offset = [(16 << shift) + 1, (16 << shift) + 1]
137 elif error_name == ErrorIf.OffsetSmallerEqualMin:
138 offset = [(-16 << shift) - 1, (-16 << shift) - 1]
139
140 if error_name == ErrorIf.WrongOutputType:
141 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
142 incorrect_types = (
143 DType.INT4,
144 DType.INT16,
145 DType.INT32,
146 DType.INT48,
147 DType.FLOAT,
148 )
149 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
150 incorrect_types = (
151 DType.INT4,
152 DType.INT8,
153 DType.INT32,
154 DType.INT48,
155 DType.FLOAT,
156 )
157 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
158 incorrect_types = (
159 DType.INT4,
160 DType.INT8,
161 DType.INT16,
162 DType.INT48,
163 DType.FLOAT,
164 )
165 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
166 incorrect_types = (
167 DType.INT4,
168 DType.INT8,
169 DType.INT16,
170 DType.INT32,
171 DType.FLOAT,
172 )
173 elif dtype == DType.FLOAT:
174 incorrect_types = (
175 DType.INT4,
176 DType.INT8,
177 DType.INT16,
178 DType.INT32,
179 DType.INT48,
180 )
181 outputDType = testGen.rng.choice(a=incorrect_types)
182
183 return shift, stride, stride_fp, offset, offset_fp, outputDType
184
185 @staticmethod
186 def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
187 if (
188 error_name == ErrorIf.StrideSmallerOne
189 # padding must not exceed the kernel size
190 and pad[0] < kernel[0]
191 and pad[1] < kernel[0]
192 and pad[2] < kernel[1]
193 and pad[3] < kernel[1]
194 ):
195 wrongStride = (
196 testGen.rng.choice([0, -1, -2, -3]),
197 testGen.rng.choice([0, -1, -2, -3]),
198 )
199 return wrongStride, pad, kernel
200 elif error_name == ErrorIf.PadSmallerZero:
201 wrongPad = (
202 testGen.rng.choice([-1, -2, -3]),
203 testGen.rng.choice([-1, -2, -3]),
204 testGen.rng.choice([-1, -2, -3]),
205 testGen.rng.choice([-1, -2, -3]),
206 )
207 return stride, wrongPad, kernel
208 elif error_name == ErrorIf.KernelSmallerOne:
209 wrongKernel = (
210 testGen.rng.choice([0, -1, -2, -3]),
211 testGen.rng.choice([0, -1, -2, -3]),
212 )
213 return stride, pad, wrongKernel
214 elif error_name == ErrorIf.PadLargerEqualKernel:
215 wrongPad = (
216 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
217 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
218 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
219 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
220 )
221 return stride, wrongPad, kernel
222 else:
223 return None, None, None
224
225 @staticmethod
226 def eiRescaleWrongOutputType(input_dtype, output_dtype):
227 if input_dtype == DType.INT8:
228 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
229 return True
230 if input_dtype in [DType.INT16, DType.INT32]:
231 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
232 return True
233 elif input_dtype == DType.INT48:
234 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
235 return True
236 elif input_dtype == DType.UINT8:
237 if output_dtype != DType.INT8:
238 return True
239 return False
240
241 @staticmethod
242 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
243 # Mess up input/output tensors for ERROR_IF checks
244 if error_name == "WrongInputList":
245 add_input = testGen.rng.choice([True, False])
246 if add_input:
247 input_list.append("eiDummyInput")
248 else:
249 input_list = input_list[:-1]
250 elif error_name == "WrongOutputList":
251 add_output = testGen.rng.choice([True, False])
252 if add_output:
253 output_list.append("eiDummyOutput")
254 else:
255 output_list = []
256 return input_list, output_list
257
258 @staticmethod
259 def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
260 """Restrict the dimensions and overall size of a shape to
261 max_dim and max_items.
262 """
263 new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
264 while product(new_shape) > max_items:
265 new_shape = [max(d - 1, 1) for d in new_shape]
266 return new_shape
267
268 def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
269 if error_name == ErrorIf.StartSmallerZero:
270 newStart = []
271 for i in range(len(input_shape)):
272 newStart.append(testGen.rng.choice([-3, -2, -1]))
273 return newStart, size
274 elif error_name == ErrorIf.SizeSmallerEqualZero:
275 newSize = []
276 for i in range(len(input_shape)):
277 newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
278 return start, newSize
279 elif error_name == ErrorIf.StartSizeOutsideBounds:
280 newStart, newSize = [], []
281 for i in range(len(input_shape)):
282 newStart.append(input_shape[i] - 1)
283 newSize.append(testGen.rng.choice([2, 3, 4]))
284 return newStart, newSize
285 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
286 remove = testGen.rng.choice([True, False])
287 if remove:
288 newStart = start[1:]
289 newSize = size[1:]
290 else:
291 newStart = start
292 newStart.append(1)
293 newSize = size
294 newSize.append(1)
295 return newStart, newSize
296 else:
297 return start, size
298
299 @staticmethod
300 def eiCastErrorIf(testGen, input_dtype):
301 if input_dtype in [DType.BOOL, DType.FLOAT]:
302 outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
303 elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
304 outputDType = [DType.INT48]
305 else:
306 assert True, f"input_dtype ({input_dtype}) not supported"
307 return outputDType
308
309
310class TosaErrorValidator:
311 @staticmethod
312 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
313 """Check ERROR_IF statements are caught and set the expected result.
314
315 Args:
316 serializer: the serializer to set the expected result in
317 validator_fcns: a sequence of validator functions to verify the result
318 error_name: the name of the ERROR_IF condition to check for
319 kwargs: keyword arguments for the validator functions
320 Returns:
321 True if the result matches the expected result; otherwise False
322 """
323 overall_result = True
324 for val_fcn in validator_fcns:
325 val_result = val_fcn(True, **kwargs)
326 validator_name = val_result["error_name"]
327 error_result = val_result["error_result"]
328 error_reason = val_result["error_reason"]
329
330 # expect an error IFF the error_name and validator_name match
331 expected_result = error_result == (error_name == validator_name)
332 overall_result &= expected_result
333
334 if expected_result and error_result:
335 serializer.setExpectedReturnCode(2, True, desc=error_reason)
336 elif error_result: # and not expected_result
337 print(
338 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
339 f" Expected: {error_name}, Got: {validator_name}"
340 )
341 elif not expected_result: # and not error_result
342 print(
343 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
344 f" Expected: {error_name}"
345 )
346
347 if not expected_result:
348 for k, v in sorted(kwargs.items()):
349 if k != "op":
350 if k.endswith("dtype"):
351 v = valueToName(DType, v)
352 print(f" {k} = {v}")
353
354 return overall_result
355
356 @staticmethod
357 def evWrongInputType(check=False, **kwargs):
358 error_result = False
359
360 # Find the unsupported input data types
361 op = kwargs["op"]
362 input_dtypes = op["types"]
363 allowed_input_dtypes = {
364 t[0] if isinstance(t, list) else t for t in input_dtypes
365 }
366 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
367
368 if op["op"] == Op.CLAMP:
369 wrong_input_dtypes.remove(DType.INT48)
370
371 if check:
372 input_dtype = kwargs["input_dtype"]
373 if input_dtype not in allowed_input_dtypes:
374 error_result = True
375
376 info_dict = {
377 "error_name": ErrorIf.WrongInputType,
378 "error_result": error_result,
379 "error_reason": "Input data type not supported for this operator",
380 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
381 }
382 return info_dict
383
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for the WrongOutputType ERROR_IF condition.

        Each operator (or family of operators) permits only specific
        input-dtype/output-dtype pairings; flag an error when the supplied
        pairing is not one of them.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # NEAREST keeps the input dtype; BILINEAR widens INT8->INT32
                # and INT16->INT48; FLOAT stays FLOAT.
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # Same table as TosaErrorIfArgGen.eiRescaleWrongOutputType.
                # NOTE(review): the INT16/INT32 branch is a plain `if`, not
                # `elif`; harmless today because INT8 cannot also match it,
                # but worth confirming it is intentional.
                if input_dtype == DType.INT8:
                    if output_dtype not in [
                        DType.UINT8,
                        DType.INT8,
                        DType.INT16,
                        DType.INT32,
                    ]:
                        error_result = True
                if input_dtype in [DType.INT16, DType.INT32]:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.INT48:
                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                        error_result = True
                elif input_dtype == DType.UINT8:
                    if output_dtype != DType.INT8:
                        error_result = True

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # Accumulator widening: INT8->INT32, INT16->INT48.
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always returns INT32 indices.
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL widens to INT32; float MUL stays FLOAT.
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison operators always produce BOOL.
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # CAST must change the dtype, within the supported set.
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                # Convolutions accumulate: INT8->INT32, INT16->INT48.
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default: elementwise-style ops preserve the input dtype.
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
528
529 @staticmethod
530 def evWrongRank(check=False, **kwargs):
531 all_ranks = (1, 2, 3, 4, 5)
532
533 # Make a list of incorrect ranks
534 assert "op" in kwargs
535 op = kwargs["op"]
536 rmin, rmax = op["rank"]
537 rank_range = range(rmin, rmax + 1)
538 incorrect_ranks = list(set(all_ranks) - set(rank_range))
539 # Remove small incorrect ranks to avoid index errors
540 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
541 # Set minimum incorrect rank to 3 to avoid index error
542 if op["op"] in [Op.RESIZE]:
543 incorrect_ranks = [3, 5]
544 elif op["op"] in [Op.TRANSPOSE]:
545 incorrect_ranks = [7, 8]
546 elif op["op"] in [Op.CONV3D]:
547 incorrect_ranks = [6, 7]
548
549 error_name = ErrorIf.WrongRank
550 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
551 error_result = False
552 error_reason = "Rank not supported for this operator"
553
554 if check:
555 input_shape = kwargs["input_shape"]
556
557 if (
558 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
559 and len(input_shape) != 4
560 ):
561 error_result = True
562 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
563 error_result = True
564 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
565 error_result = True
566 else:
567 if len(input_shape) not in rank_range:
568 error_result = True
569
570 info_dict = {
571 "error_name": error_name,
572 "error_result": error_result,
573 "error_reason": error_reason,
574 "param_reqs": param_reqs,
575 }
576 return info_dict
577
578 @staticmethod
579 def evWrongInputList(check=False, **kwargs):
580 error_name = ErrorIf.WrongInputList
581 param_reqs = {"rank": None, "dtype": None, "shape": None}
582 error_result = False
583 error_reason = "Op input list does not match expected input"
584
585 if check:
586 op = kwargs["op"]
587 input_list = kwargs["input_list"]
588 num_operands = kwargs["num_operands"]
589 if op["op"] in [Op.SCATTER, Op.GATHER]:
590 # SCATTER/GATHER add an indices input tensor in their build functions
591 num_operands += 1
592 if len(input_list) != num_operands:
593 error_result = True
594
595 info_dict = {
596 "error_name": error_name,
597 "error_result": error_result,
598 "error_reason": error_reason,
599 "param_reqs": param_reqs,
600 }
601 return info_dict
602
603 @staticmethod
604 def evWrongOutputList(check=False, **kwargs):
605 error_name = ErrorIf.WrongOutputList
606 param_reqs = {"rank": None, "dtype": None, "shape": None}
607 error_result = False
608 error_reason = "Op output list does not match expected output"
609
610 if check:
611 output_list = kwargs["output_list"]
612 # Note this will be incorrect if an operator returns more than one output
613 if len(output_list) != 1:
614 error_result = True
615
616 info_dict = {
617 "error_name": error_name,
618 "error_result": error_result,
619 "error_reason": error_reason,
620 "param_reqs": param_reqs,
621 }
622 return info_dict
623
624 @staticmethod
625 def evMaxDimExceeded(check=False, **kwargs):
626 error_name = ErrorIf.MaxDimExceeded
627 param_reqs = {
628 "rank": [4, 4],
629 "dtype": [DType.INT8],
630 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
631 }
632 error_result = False
633 error_reason = (
634 "At least one maximum dimension is greater than or equal to 16384"
635 )
636
637 if check:
638 input_shape = kwargs["input_shape"]
639 output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
640 if (
641 (input_shape[1] >= 16384)
642 or (input_shape[2] >= 16384)
643 or (output_shape[0] >= 16384)
644 or (output_shape[1] >= 16384)
645 ):
646 error_result = True
647
648 info_dict = {
649 "error_name": error_name,
650 "error_result": error_result,
651 "error_reason": error_reason,
652 "param_reqs": param_reqs,
653 }
654 return info_dict
655
656 @staticmethod
657 def evBatchMismatch(check=False, **kwargs):
658 error_name = ErrorIf.BatchMismatch
659 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
660 error_result = False
661 error_reason = "Input batch size not equal to output batch size"
662
663 assert "op" in kwargs
664 op = kwargs["op"]
665 rmin, rmax = op["rank"]
666 rank_range = range(rmin, rmax + 1)
667
668 if check:
669 input_shape = kwargs["input_shape"]
670 output_shape = kwargs[
671 "result_tensor"
672 ].shape # Note this is just (N, OH, OW, C)
673
674 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
675 error_result = True
676
677 info_dict = {
678 "error_name": error_name,
679 "error_result": error_result,
680 "error_reason": error_reason,
681 "param_reqs": param_reqs,
682 }
683 return info_dict
684
685 @staticmethod
686 def evChannelMismatch(check=False, **kwargs):
687 error_name = ErrorIf.ChannelMismatch
688 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
689 error_result = False
690 error_reason = "Input channel size not equal to output channel size"
691
692 assert "op" in kwargs
693 op = kwargs["op"]
694 rmin, rmax = op["rank"]
695 rank_range = range(rmin, rmax + 1)
696
697 if check:
698 input_shape = kwargs["input_shape"]
699 output_shape = kwargs[
700 "result_tensor"
701 ].shape # Note this is just (N, OH, OW, C)
702 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
703 error_result = True
704
705 info_dict = {
706 "error_name": error_name,
707 "error_result": error_result,
708 "error_reason": error_reason,
709 "param_reqs": param_reqs,
710 }
711 return info_dict
712
713 @staticmethod
714 def evStrideSmallerEqualZero(check=False, **kwargs):
715 error_name = ErrorIf.StrideSmallerEqualZero
716 param_reqs = {"rank": None, "dtype": None, "shape": None}
717 error_result = False
718 error_reason = "Stride value smaller than or equal zero"
719
720 if check:
721 input_dtype = kwargs["input_dtype"]
722 output_dtype = kwargs["output_dtype"]
723 if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
724 stride = kwargs["stride"] # Work around wrong input/output type tests
725 elif output_dtype == DType.FLOAT:
726 stride = kwargs["stride_fp"]
727 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
728 stride = kwargs[
729 "stride_fp"
730 ] # Work around wrong input/output type tests
731 else:
732 stride = kwargs["stride"]
733
734 if min(stride) <= 0:
735 error_result = True
736
737 info_dict = {
738 "error_name": error_name,
739 "error_result": error_result,
740 "error_reason": error_reason,
741 "param_reqs": param_reqs,
742 }
743 return info_dict
744
745 @staticmethod
746 def evStrideLargerEqualMax(check=False, **kwargs):
747 error_name = ErrorIf.StrideLargerEqualMax
748 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
749 error_result = False
750 error_reason = "Stride value larger than or equal to maximum value"
751
752 if check:
753 shift = kwargs["shift"]
754 input_dtype = kwargs["input_dtype"]
755 stride = kwargs["stride"]
756 if input_dtype in [DType.INT8, DType.INT16]:
757 if shift >= 0 and (
758 stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
759 ):
760 error_result = True
761 elif shift < 0 and (
762 stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
763 ):
764 error_result = True
765
766 info_dict = {
767 "error_name": error_name,
768 "error_result": error_result,
769 "error_reason": error_reason,
770 "param_reqs": param_reqs,
771 }
772 return info_dict
773
774 @staticmethod
775 def evStrideLargerDimension(check=False, **kwargs):
776 error_name = ErrorIf.StrideLargerDimension
777 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
778 error_result = False
779 error_reason = "Stride value larger than or equal to H/W dimension"
780
781 if check:
782 shape = kwargs["input_shape"]
783 input_dtype = kwargs["input_dtype"]
784 stride = kwargs["stride_fp"]
785
786 if (
787 input_dtype == DType.FLOAT
788 and (stride[0] > shape[1])
789 or (stride[1] > shape[2])
790 ):
791 error_result = True
792
793 info_dict = {
794 "error_name": error_name,
795 "error_result": error_result,
796 "error_reason": error_reason,
797 "param_reqs": param_reqs,
798 }
799 return info_dict
800
801 @staticmethod
802 def evOffsetSmallerEqualMin(check=False, **kwargs):
803 error_name = ErrorIf.OffsetSmallerEqualMin
804 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
805 error_result = False
806 error_reason = "Offset value smaller than or equal to minimum value"
807
808 if check:
809 shift = kwargs["shift"]
810 output_dtype = kwargs["output_dtype"]
811 if output_dtype == DType.FLOAT:
812 offset = kwargs["offset_fp"]
813 else:
814 offset = kwargs["offset"]
815
816 if shift >= 0 and (
817 offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
818 ):
819 error_result = True
820 elif shift < 0 and (
821 offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
822 ):
823 error_result = True
824
825 info_dict = {
826 "error_name": error_name,
827 "error_result": error_result,
828 "error_reason": error_reason,
829 "param_reqs": param_reqs,
830 }
831 return info_dict
832
833 @staticmethod
834 def evOffsetLargerEqualMax(check=False, **kwargs):
835 error_name = ErrorIf.OffsetLargerEqualMax
836 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
837 error_result = False
838 error_reason = "Offset value larger than or equal to maximum value"
839
840 if check:
841 shift = kwargs["shift"]
842 output_dtype = kwargs["output_dtype"]
843 if output_dtype == DType.FLOAT:
844 offset = kwargs["offset_fp"]
845 else:
846 offset = kwargs["offset"]
847
848 if shift >= 0:
849 if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
850 error_result = True
851
852 if shift >= 0 and (
853 offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
854 ):
855 error_result = True
856 elif shift < 0 and (
857 offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
858 ):
859 error_result = True
860
861 info_dict = {
862 "error_name": error_name,
863 "error_result": error_result,
864 "error_reason": error_reason,
865 "param_reqs": param_reqs,
866 }
867 return info_dict
868
869 @staticmethod
870 def evShiftNotZero(check=False, **kwargs):
871 error_name = ErrorIf.ShiftNotZero
872 param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
873 error_result = False
874 error_reason = "Shift value must be zero for float input"
875
876 if check:
877 shift = kwargs["shift"]
878 input_dtype = kwargs["input_dtype"]
879 output_dtype = kwargs["output_dtype"]
880 if (
881 input_dtype == DType.FLOAT
882 and output_dtype == DType.FLOAT
883 and shift != 0
884 ):
885 error_result = True
886
887 info_dict = {
888 "error_name": error_name,
889 "error_result": error_result,
890 "error_reason": error_reason,
891 "param_reqs": param_reqs,
892 }
893 return info_dict
894
895 @staticmethod
896 def evShiftSmallerOne(check=False, **kwargs):
897 error_name = ErrorIf.ShiftSmallerOne
898 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
899 error_result = False
900 error_reason = "Shift value smaller than one"
901
902 if check:
903 shift = kwargs["shift"]
904 input_dtype = kwargs["input_dtype"]
905 output_dtype = kwargs["output_dtype"]
906 if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
907 error_result = True
908
909 info_dict = {
910 "error_name": error_name,
911 "error_result": error_result,
912 "error_reason": error_reason,
913 "param_reqs": param_reqs,
914 }
915 return info_dict
916
917 @staticmethod
918 def evShiftLargerEleven(check=False, **kwargs):
919 error_name = ErrorIf.ShiftLargerEleven
920 param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
921 error_result = False
922 error_reason = "Shift value larger than eleven"
923
924 if check:
925 shift = kwargs["shift"]
926 if shift > 11:
927 error_result = True
928
929 info_dict = {
930 "error_name": error_name,
931 "error_result": error_result,
932 "error_reason": error_reason,
933 "param_reqs": param_reqs,
934 }
935 return info_dict
936
937 @staticmethod
938 def evRankMismatch(check=False, **kwargs):
939 error_name = ErrorIf.RankMismatch
940 param_reqs = {"rank": None, "dtype": None, "shape": None}
941 error_result = False
942 error_reason = "Input Rank does not match output rank"
943
944 if check:
945 input1_shape = kwargs["input1"].shape
946 input2_shape = kwargs["input2"].shape
947 # In case of SELECT op
948 input3_shape = (
949 kwargs["input3"].shape if "input3" in kwargs else input2_shape
950 )
951 output_shape = kwargs["result_tensor"].shape
952 if (
953 (len(input1_shape) != len(output_shape))
954 or (len(input2_shape) != len(output_shape))
955 or (len(input3_shape) != len(output_shape))
956 ):
957 error_result = True
958
959 info_dict = {
960 "error_name": error_name,
961 "error_result": error_result,
962 "error_reason": error_reason,
963 "param_reqs": param_reqs,
964 }
965 return info_dict
966
967 @staticmethod
968 def evDimensionMismatch(check=False, **kwargs):
969 error_name = ErrorIf.DimensionMismatch
970 param_reqs = {"rank": None, "dtype": None, "shape": None}
971 error_result = False
972 error_reason = "Input Dimensions do not match output"
973
974 if check:
975 input1_shape = kwargs["input1"].shape
976 input2_shape = kwargs["input2"].shape
977 # In case of SELECT op
978 input3_shape = (
979 kwargs["input3"].shape if "input3" in kwargs else input2_shape
980 )
981 output_shape = kwargs["result_tensor"].shape
982 for i in range(
983 min(len(input1_shape), len(input2_shape), len(input3_shape))
984 ):
985 if (
986 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
987 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
988 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
989 ):
990 error_result = True
991
992 info_dict = {
993 "error_name": error_name,
994 "error_result": error_result,
995 "error_reason": error_reason,
996 "param_reqs": param_reqs,
997 }
998 return info_dict
999
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """Validator for the InputZeroPointNotZero ERROR_IF condition.

        Flags an error when an input whose dtype is not quantizable
        (i.e. not INT8/UINT8) carries a non-zero input zero point.
        """
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            # qinfo arrives either as a plain tuple (first element = input_zp)
            # or as a serialized object exposing `.ints`.
            if isinstance(kwargs["qinfo"], tuple):
                qinfo = kwargs["qinfo"]
                input_zero_point = qinfo[0]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs["qinfo"].ints
                input_zero_point = qinfo[0][1]

            if op["op"] == Op.MATMUL:
                # MATMUL has two inputs, each with its own dtype/zero point.
                qinfo = kwargs["qinfo"].ints
                for dtype, zp in (
                    (kwargs["input_dtype"], qinfo[0][1]),
                    (kwargs["input2_dtype"], qinfo[1][1]),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict
1045
1046 @staticmethod
1047 def evWeightZeroPointNotZero(check=False, **kwargs):
1048 op = kwargs["op"]
1049
1050 # exclude inputs with INT8 weights
1051 inputDtypes = [
1052 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1053 ]
1054
1055 error_name = ErrorIf.WeightZeroPointNotZero
1056 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1057 error_result = False
1058 error_reason = "Weight DType not INT8 and zero point not 0"
1059
1060 if check:
1061 weight_dtype = kwargs["weight_dtype"]
1062 # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
1063 qinfo = kwargs["qinfo"].ints
1064 weight_zero_point = qinfo[1][1]
1065 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1066 error_result = True
1067
1068 info_dict = {
1069 "error_name": error_name,
1070 "error_result": error_result,
1071 "error_reason": error_reason,
1072 "param_reqs": param_reqs,
1073 }
1074 return info_dict
1075
    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.OutputZeroPointNotZero.

        Flags a non-zero output zero point on a non-quantizable output dtype.
        AVG_POOL2D is special-cased: there the *input* dtype decides whether
        a non-zero output zero point is an error.
        """
        op = kwargs["op"]
        # param_reqs dtype list: op types minus the quantizable ones
        inputDtypes = op["types"].copy()
        if DType.INT8 in inputDtypes:
            inputDtypes.remove(DType.INT8)
        if DType.UINT8 in inputDtypes:
            inputDtypes.remove(DType.UINT8)

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            # qinfo arrives either as a plain (input_zp, output_zp) tuple or as
            # a serialized object exposing an .ints list
            if isinstance(kwargs["qinfo"], tuple):
                qinfo = kwargs["qinfo"]
                output_zero_point = qinfo[1]
            else:
                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
                qinfo = kwargs["qinfo"].ints
                output_zero_point = qinfo[1][1]
            if op["op"] == Op.AVG_POOL2D:
                # AVG_POOL2D: output zp validity follows the input dtype
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1115
1116 @staticmethod
1117 def evAxisSmallerZero(check=False, **kwargs):
1118 error_name = ErrorIf.AxisSmallerZero
1119 param_reqs = {"rank": None, "dtype": None, "shape": None}
1120 error_result = False
1121 error_reason = "Axis smaller than zero"
1122
1123 if check:
1124 axis = kwargs["axis"]
1125 if axis < 0:
1126 error_result = True
1127
1128 info_dict = {
1129 "error_name": error_name,
1130 "error_result": error_result,
1131 "error_reason": error_reason,
1132 "param_reqs": param_reqs,
1133 }
1134 return info_dict
1135
    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.AxisLargerRank.

        Flags an axis attribute that exceeds the input rank.
        """
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            # NOTE(review): valid axes are 0..rank-1, yet this uses `>` so
            # axis == rank is NOT flagged here — presumably intentional to
            # match what the reference model reports; confirm before changing.
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1156
1157 @staticmethod
1158 def evShapeOfAxisNotOne(check=False, **kwargs):
1159 error_name = ErrorIf.ShapeOfAxisNotOne
1160 param_reqs = {"rank": None, "dtype": None, "shape": None}
1161 error_result = False
1162 error_reason = "shape[axis] is not equal to 1"
1163
1164 if check:
1165 axis = kwargs["axis"]
1166 shape = kwargs["output_shape"]
1167 if (0 <= axis < len(shape)) and shape[axis] != 1:
1168 error_result = True
1169
1170 info_dict = {
1171 "error_name": error_name,
1172 "error_result": error_result,
1173 "error_reason": error_reason,
1174 "param_reqs": param_reqs,
1175 }
1176 return info_dict
1177
1178 @staticmethod
1179 def evPadSmallerZero(check=False, **kwargs):
1180 error_name = ErrorIf.PadSmallerZero
1181 param_reqs = {"rank": None, "dtype": None, "shape": None}
1182 error_result = False
1183 error_reason = "At least one pad is smaller than zero"
1184
1185 if check:
1186 op = kwargs["op"]
1187 pad = kwargs["pad"]
1188 if op["op"] == Op.PAD:
1189 for padding in pad:
1190 if min(padding) < 0:
1191 error_result = True
1192 else:
1193 if min(pad) < 0:
1194 error_result = True
1195
1196 info_dict = {
1197 "error_name": error_name,
1198 "error_result": error_result,
1199 "error_reason": error_reason,
1200 "param_reqs": param_reqs,
1201 }
1202 return info_dict
1203
1204 @staticmethod
1205 def evPadLargerEqualKernel(check=False, **kwargs):
1206 error_name = ErrorIf.PadLargerEqualKernel
1207 param_reqs = {"rank": None, "dtype": None, "shape": None}
1208 error_result = False
1209 error_reason = "At least one pad is larger than kernel dimension"
1210
1211 if check:
1212 pad = kwargs["pad"]
1213 kernel = kwargs["kernel"]
1214 if min(pad) > 0 and min(kernel) > 1:
1215 if (
1216 pad[0] >= kernel[0]
1217 or pad[1] >= kernel[0]
1218 or pad[2] >= kernel[1]
1219 or pad[3] >= kernel[1]
1220 ):
1221 error_result = True
1222
1223 info_dict = {
1224 "error_name": error_name,
1225 "error_result": error_result,
1226 "error_reason": error_reason,
1227 "param_reqs": param_reqs,
1228 }
1229 return info_dict
1230
1231 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001232 def checkPoolingParams(kernel, stride, pad):
1233 return (
1234 min(kernel) >= 1
1235 and min(stride) >= 1
1236 and min(pad) >= 0
1237 and not (
1238 pad[0] >= kernel[0]
1239 or pad[1] >= kernel[0]
1240 or pad[2] >= kernel[1]
1241 or pad[3] >= kernel[1]
1242 )
1243 )
1244
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.PoolingOutputShapeMismatch.

        Flags a provided pooling output shape (NHWC) whose H/W differ from
        the shape computed from input size, pad, kernel and stride.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            # NHWC layout: index 1 is height, index 2 is width
            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            # a zero stride makes params_valid False (strides must be >= 1),
            # so y_correct/x_correct are never read when undefined above
            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1288
    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.PoolingOutputShapeNonInteger.

        Flags pooling parameters where (in + pads - kernel) is not an exact
        multiple of the stride, i.e. the output dimensions would not be
        integers.
        """
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            # NHWC layout: index 1 is height, index 2 is width
            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            # (a zero stride makes params_valid False, so the remainders are
            # never read when undefined above)
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1326
1327 @staticmethod
1328 def checkConvParams(weight_shape, stride, pad, dilation):
1329 return (
1330 # Check kernel sizes
1331 min(weight_shape[1:-1]) >= 1
1332 and min(stride) >= 1
1333 and min(pad) >= 0
1334 and (dilation is None or min(dilation) >= 1)
1335 )
1336
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.ConvOutputShapeMismatch.

        Flags a convolution output shape whose spatial dims differ from the
        dims computed from input shape, weights, pad, stride and dilation.
        TRANSPOSE_CONV2D uses the deconvolution formula and has no dilation.
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # TRANSPOSE_CONV2D takes no dilation attribute
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights are laid out with the spatial dims
            # starting at index 0; other convs start at index 1
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            - pad[pad_offset]
                            - pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                weight_shape, stride, pad, dilation
            )

            # compare only the spatial dims (strip batch and channel)
            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1398
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.ConvOutputShapeNonInteger.

        Flags convolution parameters where the output-dimension division
        leaves a remainder, i.e. the output dims would not be integers.
        """
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights are laid out with the spatial dims
            # starting at index 0; other convs start at index 1
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            # (a non-positive stride makes params_valid False, so the empty
            # `remainders` list is never passed to max() below)
            params_valid = TosaErrorValidator.checkConvParams(
                weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1447
1448 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001449 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1450 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1451 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1452 error_result = False
1453 error_reason = (
1454 "Mismatch between output shape provided and expected output shape"
1455 )
1456
1457 if check:
1458 output_shape = kwargs["output_shape"]
1459 input_shape = kwargs["input_shape"]
1460 axis = kwargs["axis"]
1461
1462 dimension_match = True
1463 axis_shift = 0
1464
1465 # Check that rank is correct before trying to check dimensions
1466 if (len(input_shape) - 1) == len(output_shape):
1467 for i in range(len(input_shape)):
1468 if i == axis:
1469 axis_shift = 1
1470 continue
1471 if input_shape[i] != output_shape[i - axis_shift]:
1472 dimension_match = False
1473
1474 if not dimension_match:
1475 error_result = True
1476
1477 info_dict = {
1478 "error_name": error_name,
1479 "error_result": error_result,
1480 "error_reason": error_reason,
1481 "param_reqs": param_reqs,
1482 }
1483 return info_dict
1484
1485 @staticmethod
1486 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1487 error_name = ErrorIf.ArgmaxOutputRankMismatch
1488 param_reqs = {"rank": None, "dtype": None, "shape": None}
1489 error_result = False
1490 error_reason = (
1491 "Mismatch between output shape provided and expected output shape"
1492 )
1493
1494 if check:
1495 output_shape = kwargs["output_shape"]
1496 input_shape = kwargs["input_shape"]
1497 axis = kwargs["axis"]
1498 valid_params = axis >= 0 and axis < len(input_shape)
1499
1500 if valid_params and (len(input_shape) - 1) != len(output_shape):
1501 error_result = True
1502
1503 info_dict = {
1504 "error_name": error_name,
1505 "error_result": error_result,
1506 "error_reason": error_reason,
1507 "param_reqs": param_reqs,
1508 }
1509 return info_dict
1510
1511 @staticmethod
1512 def evKernelSmallerOne(check=False, **kwargs):
1513 error_name = ErrorIf.KernelSmallerOne
1514 param_reqs = {"rank": None, "dtype": None, "shape": None}
1515 error_result = False
1516 error_reason = "At least one kernel dimension is smaller than zero"
1517
1518 if check:
1519 kernel = kwargs["kernel"]
1520 if min(kernel) < 1:
1521 error_result = True
1522
1523 info_dict = {
1524 "error_name": error_name,
1525 "error_result": error_result,
1526 "error_reason": error_reason,
1527 "param_reqs": param_reqs,
1528 }
1529 return info_dict
1530
1531 @staticmethod
1532 def evStrideSmallerOne(check=False, **kwargs):
1533 error_name = ErrorIf.StrideSmallerOne
1534 param_reqs = {"rank": None, "dtype": None, "shape": None}
1535 error_result = False
1536 error_reason = "At least one stride dimension is smaller than zero"
1537
1538 if check:
1539 stride = kwargs["stride"]
1540 if min(stride) < 1:
1541 error_result = True
1542
1543 info_dict = {
1544 "error_name": error_name,
1545 "error_result": error_result,
1546 "error_reason": error_reason,
1547 "param_reqs": param_reqs,
1548 }
1549 return info_dict
1550
1551 @staticmethod
1552 def evDilationSmallerOne(check=False, **kwargs):
1553 error_result = check and min(kwargs["dilation"]) < 1
1554 return {
1555 "error_name": ErrorIf.DilationSmallerOne,
1556 "error_reason": "At least one dilation is smaller than one",
1557 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1558 "error_result": error_result,
1559 }
1560
1561 @staticmethod
1562 def evScaleTrue(check=False, **kwargs):
1563 error_name = ErrorIf.ScaleTrue
1564 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1565 error_result = False
1566 error_reason = "Scale set to true but input type is INT48"
1567
1568 if check:
1569 input_dtype = kwargs["input_dtype"]
1570 scale32 = kwargs["scale32"]
1571 if scale32 and input_dtype == DType.INT48:
1572 error_result = True
1573
1574 info_dict = {
1575 "error_name": error_name,
1576 "error_result": error_result,
1577 "error_reason": error_reason,
1578 "param_reqs": param_reqs,
1579 }
1580 return info_dict
1581
1582 @staticmethod
1583 def evScaleNotTrue(check=False, **kwargs):
1584 error_name = ErrorIf.ScaleNotTrue
1585 param_reqs = {"rank": None, "dtype": None, "shape": None}
1586 error_result = False
1587 error_reason = "Scale set to false but double round set to true"
1588
1589 if check:
1590 scale32 = kwargs["scale32"]
1591 double_round = kwargs["double_round"]
1592 if not scale32 and double_round:
1593 error_result = True
1594
1595 info_dict = {
1596 "error_name": error_name,
1597 "error_result": error_result,
1598 "error_reason": error_reason,
1599 "param_reqs": param_reqs,
1600 }
1601 return info_dict
1602
1603 @staticmethod
1604 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1605 error_name = ErrorIf.TensorSizeInputOutputMismatch
1606 param_reqs = {"rank": None, "dtype": None, "shape": None}
1607 error_result = False
1608 error_reason = "Input tensor size does not match output tensor size"
1609
1610 if check:
1611 input_shape = kwargs["input_shape"]
1612 output_shape = kwargs["output_shape"]
1613 input_size = np.prod(input_shape)
1614 output_size = np.prod(output_shape)
1615 if input_size != output_size:
1616 error_result = True
1617
1618 info_dict = {
1619 "error_name": error_name,
1620 "error_result": error_result,
1621 "error_reason": error_reason,
1622 "param_reqs": param_reqs,
1623 }
1624 return info_dict
1625
1626 @staticmethod
1627 def evStartSmallerZero(check=False, **kwargs):
1628 error_name = ErrorIf.StartSmallerZero
1629 param_reqs = {"rank": None, "dtype": None, "shape": None}
1630 error_result = False
1631 error_reason = "Starting point smaller than zero"
1632
1633 if check:
1634 input_shape = kwargs["input_shape"]
1635 start = kwargs["start"]
1636 rank = len(input_shape)
1637 if len(start) == rank:
1638 for index in range(rank):
1639 if start[index] < 0:
1640 error_result = True
1641
1642 info_dict = {
1643 "error_name": error_name,
1644 "error_result": error_result,
1645 "error_reason": error_reason,
1646 "param_reqs": param_reqs,
1647 }
1648 return info_dict
1649
1650 @staticmethod
1651 def evSizeSmallerEqualZero(check=False, **kwargs):
1652 error_name = ErrorIf.SizeSmallerEqualZero
1653 param_reqs = {"rank": None, "dtype": None, "shape": None}
1654 error_result = False
1655 error_reason = "Size smaller than or equal to zero"
1656
1657 if check:
1658 input_shape = kwargs["input_shape"]
1659 size = kwargs["size"]
1660 rank = len(input_shape)
1661 if len(size) == rank:
1662 for index in range(rank):
1663 if size[index] <= 0:
1664 error_result = True
1665
1666 info_dict = {
1667 "error_name": error_name,
1668 "error_result": error_result,
1669 "error_reason": error_reason,
1670 "param_reqs": param_reqs,
1671 }
1672 return info_dict
1673
1674 @staticmethod
1675 def evStartSizeOutsideBounds(check=False, **kwargs):
1676 error_name = ErrorIf.StartSizeOutsideBounds
1677 param_reqs = {"rank": None, "dtype": None, "shape": None}
1678 error_result = False
1679 error_reason = "starting point plus size larger than input dimension"
1680
1681 if check:
1682 input_shape = kwargs["input_shape"]
1683 start = kwargs["start"]
1684 size = kwargs["size"]
1685 rank = len(input_shape)
1686 if len(start) == rank and len(size) == rank:
1687 for index in range(rank):
1688 if start[index] + size[index] > input_shape[index]:
1689 error_result = True
1690
1691 info_dict = {
1692 "error_name": error_name,
1693 "error_result": error_result,
1694 "error_reason": error_reason,
1695 "param_reqs": param_reqs,
1696 }
1697 return info_dict
1698
1699 @staticmethod
1700 def evSizeOutputShapeMismatch(check=False, **kwargs):
1701 error_name = ErrorIf.SizeOutputShapeMismatch
1702 param_reqs = {"rank": None, "dtype": None, "shape": None}
1703 error_result = False
1704 error_reason = "Size does not match output dimension"
1705
1706 if check:
1707 input_shape = kwargs["input_shape"]
1708 output_shape = kwargs["output_shape"]
1709 size = kwargs["size"]
1710 rank = len(input_shape)
1711 if len(size) == rank:
1712 for index in range(rank):
1713 if size[index] != output_shape[index]:
1714 error_result = True
1715
1716 info_dict = {
1717 "error_name": error_name,
1718 "error_result": error_result,
1719 "error_reason": error_reason,
1720 "param_reqs": param_reqs,
1721 }
1722 return info_dict
1723
1724 @staticmethod
1725 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1726 error_name = ErrorIf.InputSizeStartLengthMismatch
1727 param_reqs = {"rank": None, "dtype": None, "shape": None}
1728 error_result = False
1729 error_reason = "rank of input not equal to length of start or size"
1730
1731 if check:
1732 input_shape = kwargs["input_shape"]
1733 start = kwargs["start"]
1734 size = kwargs["size"]
1735 rank = len(input_shape)
1736 if rank != len(start) or rank != len(size):
1737 error_result = True
1738
1739 info_dict = {
1740 "error_name": error_name,
1741 "error_result": error_result,
1742 "error_reason": error_reason,
1743 "param_reqs": param_reqs,
1744 }
1745 return info_dict
1746
    @staticmethod
    def evIndexOutsideBounds(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.IndexOutsideBounds.

        Flags a TRANSPOSE permutation entry outside the allowed range.
        """
        error_name = ErrorIf.IndexOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index outside of allowed bounds"

        if check:
            input_shape = kwargs["input_shape"]
            perms = kwargs["perms"]
            rank = len(input_shape)

            # NOTE(review): valid permutation indices are 0..rank-1, yet this
            # uses `index > rank`, so index == rank is NOT flagged —
            # presumably intentional to match the reference model's report;
            # confirm before tightening to `>=`.
            for index in perms:
                if index < 0 or index > rank:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1770
1771 @staticmethod
1772 def evIndexUsedTwice(check=False, **kwargs):
1773 error_name = ErrorIf.IndexUsedTwice
1774 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1775 error_result = False
1776 error_reason = "Index used multiple times"
1777
1778 if check:
1779 perms = kwargs["perms"]
1780
1781 unique_indices = []
1782 for index in perms:
1783 if index in unique_indices:
1784 error_result = True
1785 else:
1786 unique_indices.append(index)
1787
1788 info_dict = {
1789 "error_name": error_name,
1790 "error_result": error_result,
1791 "error_reason": error_reason,
1792 "param_reqs": param_reqs,
1793 }
1794 return info_dict
1795
1796 @staticmethod
1797 def evMaxSmallerMin(check=False, **kwargs):
1798 error_name = ErrorIf.MaxSmallerMin
1799 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1800 error_result = False
1801 error_reason = "Max value smaller than min value"
1802
1803 if check:
1804 max_val = kwargs["max_val"]
1805 min_val = kwargs["min_val"]
1806 if max_val < min_val:
1807 error_result = True
1808
1809 info_dict = {
1810 "error_name": error_name,
1811 "error_result": error_result,
1812 "error_reason": error_reason,
1813 "param_reqs": param_reqs,
1814 }
1815 return info_dict
1816
1817 @staticmethod
1818 def evConcatInputRankMismatch(check=False, **kwargs):
1819 error_name = ErrorIf.ConcatInputRankMismatch
1820 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1821 error_result = False
1822 error_reason = "Input ranks are not identical"
1823
1824 if check:
1825 inputs = kwargs["inputs"]
1826 input_shape = kwargs["input_shape"]
1827 for input in inputs:
1828 if len(input.shape) != len(input_shape):
1829 error_result = True
1830
1831 info_dict = {
1832 "error_name": error_name,
1833 "error_result": error_result,
1834 "error_reason": error_reason,
1835 "param_reqs": param_reqs,
1836 }
1837 return info_dict
1838
1839 @staticmethod
1840 def evConcatInputDimMismatch(check=False, **kwargs):
1841 error_name = ErrorIf.ConcatInputDimMismatch
1842 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1843 error_result = False
1844 error_reason = "Input dimensions differ on too many axes"
1845
1846 if check:
1847 inputs = kwargs["inputs"]
1848 input_shape = kwargs["input_shape"]
1849 axis = kwargs["axis"]
1850
1851 # Ensure rank is valid before checking dims.
1852 valid_rank = True
1853 for input in inputs:
1854 if len(input.shape) != len(input_shape):
1855 valid_rank = False
1856
1857 if valid_rank:
1858 for input in inputs:
1859 for i, dim in enumerate(input.shape):
1860 if dim != input_shape[i] and axis != i:
1861 error_result = True
1862
1863 info_dict = {
1864 "error_name": error_name,
1865 "error_result": error_result,
1866 "error_reason": error_reason,
1867 "param_reqs": param_reqs,
1868 }
1869 return info_dict
1870
    @staticmethod
    def evConcatShapeSumMismatch(check=False, **kwargs):
        """Build the error-info dict for ErrorIf.ConcatShapeSumMismatch.

        Flags a CONCAT output whose dimension on the concat axis is not the
        sum of the inputs' dimensions on that axis.
        """
        error_name = ErrorIf.ConcatShapeSumMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Sum of dimensions on axis not equal to output dimension"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_params = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_params = False
            # NOTE(review): `axis > len(input_shape)` treats axis == rank as
            # valid even though valid axes are 0..rank-1 — presumably matches
            # the generator's expectations; confirm before tightening to `>=`.
            if axis < 0 or axis > len(input_shape):
                valid_params = False

            if valid_params:
                axis_dim_sum = 0
                for input in inputs:
                    axis_dim_sum += input.shape[axis]

                if axis_dim_sum != output_shape[axis]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1907
1908 @staticmethod
1909 def evInputListThenGraphMismatch(check=False, **kwargs):
1910 error_name = ErrorIf.CondIfInputListThenGraphMismatch
1911 param_reqs = {"rank": None, "dtype": None, "shape": None}
1912 error_result = False
1913 error_reason = "Input list shape does not match then-graph shape"
1914
1915 if check:
1916 a = kwargs["a"]
1917 b = kwargs["b"]
1918 basicBlocks = kwargs["basicBlocks"]
1919 then_block = basicBlocks[1]
1920 then_inputs = then_block.inputs
1921 then_tens = then_block.tensors
1922 if (a.shape != then_tens[then_inputs[0]].shape) or (
1923 b.shape != then_tens[then_inputs[1]].shape
1924 ):
1925 error_result = True
1926
1927 info_dict = {
1928 "error_name": error_name,
1929 "error_result": error_result,
1930 "error_reason": error_reason,
1931 "param_reqs": param_reqs,
1932 }
1933 return info_dict
1934
1935 @staticmethod
1936 def evInputListElseGraphMismatch(check=False, **kwargs):
1937 error_name = ErrorIf.CondIfInputListElseGraphMismatch
1938 param_reqs = {"rank": None, "dtype": None, "shape": None}
1939 error_result = False
1940 error_reason = "Input list shape does not match else-graph shape"
1941
1942 if check:
1943 a = kwargs["a"]
1944 b = kwargs["b"]
1945 basicBlocks = kwargs["basicBlocks"]
1946 else_block = basicBlocks[2]
1947 else_inputs = else_block.inputs
1948 else_tens = else_block.tensors
1949 if (a.shape != else_tens[else_inputs[0]].shape) or (
1950 b.shape != else_tens[else_inputs[1]].shape
1951 ):
1952 error_result = True
1953
1954 info_dict = {
1955 "error_name": error_name,
1956 "error_result": error_result,
1957 "error_reason": error_reason,
1958 "param_reqs": param_reqs,
1959 }
1960 return info_dict
1961
1962 @staticmethod
1963 def evOutputListThenGraphMismatch(check=False, **kwargs):
1964 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
1965 param_reqs = {"rank": None, "dtype": None, "shape": None}
1966 error_result = False
1967 error_reason = "Output list shape does not match then-graph shape"
1968
1969 if check:
1970 basicBlocks = kwargs["basicBlocks"]
1971 cond_block = basicBlocks[0]
1972 cond_outputs = cond_block.outputs
1973 cond_tens = cond_block.tensors
1974 then_block = basicBlocks[1]
1975 then_outputs = then_block.outputs
1976 then_tens = then_block.tensors
1977 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
1978 error_result = True
1979
1980 info_dict = {
1981 "error_name": error_name,
1982 "error_result": error_result,
1983 "error_reason": error_reason,
1984 "param_reqs": param_reqs,
1985 }
1986 return info_dict
1987
1988 @staticmethod
1989 def evOutputListElseGraphMismatch(check=False, **kwargs):
1990 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
1991 param_reqs = {"rank": None, "dtype": None, "shape": None}
1992 error_result = False
1993 error_reason = "Output list shape does not match else-graph shape"
1994
1995 if check:
1996 basicBlocks = kwargs["basicBlocks"]
1997 cond_block = basicBlocks[0]
1998 cond_outputs = cond_block.outputs
1999 cond_tens = cond_block.tensors
2000 else_block = basicBlocks[2]
2001 else_outputs = else_block.outputs
2002 else_tens = else_block.tensors
2003 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2004 error_result = True
2005
2006 info_dict = {
2007 "error_name": error_name,
2008 "error_result": error_result,
2009 "error_reason": error_reason,
2010 "param_reqs": param_reqs,
2011 }
2012 return info_dict
2013
2014 @staticmethod
2015 def evInputListOutputListMismatch(check=False, **kwargs):
2016 error_name = ErrorIf.InputListOutputListMismatch
2017 param_reqs = {"rank": None, "dtype": None, "shape": None}
2018 error_result = False
2019 error_reason = "Input list does not match output list"
2020
2021 if check:
2022 basicBlocks = kwargs["basicBlocks"]
2023 while_block = basicBlocks[0]
2024 while_inputs = while_block.inputs
2025 while_outputs = while_block.outputs
2026 while_tens = while_block.tensors
2027 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2028 error_result = True
2029
2030 info_dict = {
2031 "error_name": error_name,
2032 "error_result": error_result,
2033 "error_reason": error_reason,
2034 "param_reqs": param_reqs,
2035 }
2036 return info_dict
2037
2038 @staticmethod
2039 def evInputListCondGraphMismatch(check=False, **kwargs):
2040 error_name = ErrorIf.InputListCondGraphMismatch
2041 param_reqs = {"rank": None, "dtype": None, "shape": None}
2042 error_result = False
2043 error_reason = "Input list does not match cond graph"
2044
2045 if check:
2046 basicBlocks = kwargs["basicBlocks"]
2047 while_block = basicBlocks[0]
2048 while_inputs = while_block.inputs
2049 while_tens = while_block.tensors
2050 cond_block = basicBlocks[1]
2051 cond_inputs = cond_block.inputs
2052 cond_tens = cond_block.tensors
2053 if (
2054 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2055 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2056 error_result = True
2057
2058 info_dict = {
2059 "error_name": error_name,
2060 "error_result": error_result,
2061 "error_reason": error_reason,
2062 "param_reqs": param_reqs,
2063 }
2064 return info_dict
2065
2066 @staticmethod
2067 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2068 error_name = ErrorIf.InputListBodyGraphInputMismatch
2069 param_reqs = {"rank": None, "dtype": None, "shape": None}
2070 error_result = False
2071 error_reason = "Input list does not match body graph input"
2072
2073 if check:
2074 basicBlocks = kwargs["basicBlocks"]
2075 while_block = basicBlocks[0]
2076 while_inputs = while_block.inputs
2077 while_tens = while_block.tensors
2078 body_block = basicBlocks[2]
2079 body_outputs = body_block.inputs
2080 body_tens = body_block.tensors
2081 if (
2082 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2083 ) or (
2084 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2085 ):
2086 error_result = True
2087
2088 info_dict = {
2089 "error_name": error_name,
2090 "error_result": error_result,
2091 "error_reason": error_reason,
2092 "param_reqs": param_reqs,
2093 }
2094 return info_dict
2095
2096 @staticmethod
2097 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2098 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2099 param_reqs = {"rank": None, "dtype": None, "shape": None}
2100 error_result = False
2101 error_reason = "Input list does not match body graph output"
2102
2103 if check:
2104 basicBlocks = kwargs["basicBlocks"]
2105 while_block = basicBlocks[0]
2106 while_inputs = while_block.inputs
2107 while_tens = while_block.tensors
2108 body_block = basicBlocks[2]
2109 body_outputs = body_block.outputs
2110 body_tens = body_block.tensors
2111 if (
2112 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2113 ) or (
2114 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2115 ):
2116 error_result = True
2117 info_dict = {
2118 "error_name": error_name,
2119 "error_result": error_result,
2120 "error_reason": error_reason,
2121 "param_reqs": param_reqs,
2122 }
2123 return info_dict
2124
2125 @staticmethod
2126 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2127 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2128 param_reqs = {"rank": None, "dtype": None, "shape": None}
2129 error_result = False
2130 error_reason = "Cond graph output is not a match list of booleans"
2131
2132 if check:
2133 basicBlocks = kwargs["basicBlocks"]
2134 cond_block = basicBlocks[1]
2135 cond_outputs = cond_block.outputs
2136 cond_tens = cond_block.tensors
2137 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2138 error_result = True
2139
2140 info_dict = {
2141 "error_name": error_name,
2142 "error_result": error_result,
2143 "error_reason": error_reason,
2144 "param_reqs": param_reqs,
2145 }
2146 return info_dict
2147
2148
2149class TosaInvalidValidator:
2150 @staticmethod
2151 def ivWrongDataTypeOrModeResize(**kwargs):
2152 input_dtype = kwargs["input_dtype"]
2153 args = kwargs["args"]
2154 mode = args[0]
2155 output_dtype = args[8]
2156
2157 if mode == ResizeMode.BILINEAR:
2158 # Invalid output data type / Invalid input datatype
2159 return (
2160 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
2161 or not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
2162 or not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
2163 or (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
2164 )
2165 elif mode == ResizeMode.NEAREST:
2166 # Invalid output data type / Invalid input datatype
2167 return (input_dtype != output_dtype) or (
2168 input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
2169 )
2170 else:
2171 # Invalid resize mode
2172 return True
2173
2174 @staticmethod
2175 def ivBadStride(**kwargs):
2176 input_dtype = kwargs["input_dtype"]
2177 args = kwargs["args"]
2178 stride_x = args[1][0]
2179 stride_y = args[1][1]
2180 stride_fp_x = args[4][0]
2181 stride_fp_y = args[4][1]
2182
2183 if input_dtype == DType.FLOAT:
2184 if stride_fp_x <= 0 or stride_fp_y <= 0:
2185 # Negative or zero stride
2186 return True
2187 else:
2188 if stride_x <= 0 or stride_y <= 0:
2189 # Negative or zero stride
2190 return True
2191 return False
2192
2193 @staticmethod
2194 def ivHeightWidthInvalid(**kwargs):
2195 opName = kwargs["opName"]
2196
2197 inputShapes = kwargs["shapeList"]
2198 input_shape = inputShapes[0]
2199
2200 args = kwargs["args"]
2201 strides = args[0]
2202 padding = args[1]
2203
2204 if opName.endswith("pool2d"):
2205 # avg_pool2d, max_pool2d
2206 kernel_shape = args[2]
2207 h = (
2208 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
2209 ) // strides[0]
2210 w = (
2211 input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
2212 ) // strides[1]
2213 # return True if any dimension is < 1
2214 return h < 1 or w < 1
2215
2216 if opName.startswith("transpose_conv2d"):
2217 # transpose_conv2d
2218 dilations = args[2]
2219 output_shape = args[3]
2220 filter_shape = inputShapes[1]
2221 kernel_shape = filter_shape[1:-1]
2222
2223 def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
2224 """Calculate the transpose_conv2d output size for a dimension.
2225
2226 Based on the keras function deconv_output_length, in
2227 https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
2228
2229 Args:
2230 in_size: the input size - int
2231 stride: the stride - int
2232 kernel_size: the kernel size - int
2233 dilation: the kernel dilation - int
2234 out_pad: the output padding - int
2235 in_pad: the input padding - int
2236
2237 Returns:
2238 the output size
2239 """
2240 dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
2241 return (
2242 (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
2243 )
2244
2245 for pad_h, pad_w in (
2246 (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
2247 (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
2248 (0, 0), # VALID padding
2249 ):
2250 h = get_out_size(
2251 input_shape[1],
2252 strides[0],
2253 kernel_shape[0],
2254 dilations[0],
2255 padding[0],
2256 pad_h,
2257 )
2258 w = get_out_size(
2259 input_shape[2],
2260 strides[1],
2261 kernel_shape[1],
2262 dilations[1],
2263 padding[1],
2264 pad_w,
2265 )
2266 if output_shape[1] == h and output_shape[2] == w:
2267 return False
2268
2269 # output shape does not match the expected shape for any padding option
2270 return True
2271
2272 if "conv2d" in opName or "conv3d" in opName:
2273 # conv2d, conv3d, depthwise_conv2d
2274 dilations = args[2]
2275 filter_shape = inputShapes[1]
2276 kernel_shape = (
2277 filter_shape[0:2]
2278 if opName.startswith("depthwise_conv2d")
2279 else filter_shape[1:-1]
2280 )
2281
2282 for i in range(len(kernel_shape)):
2283 dim = (
2284 input_shape[i + 1]
2285 - kernel_shape[i]
2286 - (kernel_shape[i] - 1) * (dilations[i] - 1)
2287 + padding[i * 2 + 0]
2288 + padding[i * 2 + 1]
2289 ) // strides[i] + 1
2290 # return True if any dimension is < 1
2291 if dim < 1:
2292 return True
2293 return False
2294
2295 assert False, f"Unrecognized Op: {opName}"
2296
2297 @staticmethod
2298 def ivNonPositiveOutputShape(**kwargs):
2299 args = kwargs["args"]
2300 output_shape = args[3]
2301 if output_shape[1] <= 0 or output_shape[2] <= 0:
2302 # Negative output shape
2303 return True
2304 return False