blob: 1651d9576308bf4db7b20056d71f7df7da5d02b2 [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005from generator.tosa_utils import product
6from generator.tosa_utils import usableDTypes
7from generator.tosa_utils import valueToName
8from tosa.DType import DType
9from tosa.Op import Op
10from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011
Matthew Haddone86fd342021-09-07 16:12:21 +010012
class ErrorIf:
    """Symbolic names for every ERROR_IF condition exercised by the test
    generator.

    Each attribute maps a condition to its string identifier.  The strings
    are used to select the matching argument corrupter (TosaErrorIfArgGen)
    and validator (TosaErrorValidator) and appear in generated test names.
    Note: explicit ``(object)`` base removed — redundant in Python 3.
    """

    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010075
76
class TosaErrorIfArgGen:
    """Helpers that corrupt otherwise-valid operator arguments so a specific
    ERROR_IF condition is provoked in the serialized test.

    Fixes applied in review:
    - ``eiSliceErrorIf`` was missing the ``@staticmethod`` decorator that all
      sibling methods carry.
    - ``eiCastErrorIf`` ended with ``assert True`` which can never fire; the
      unsupported-dtype path then crashed with UnboundLocalError instead of
      the intended assertion failure.  Changed to ``assert False``.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Mutate RESIZE arguments to provoke the requested ERROR_IF.

        Args:
            testGen: test generator providing randInt()/rng for random picks
            error_name: the ErrorIf condition to provoke
            mode: ResizeMode (NEAREST/BILINEAR), used for WrongOutputType
            dtype: input data type, used for WrongOutputType
            shapeList: input shapes (unused here; kept for a uniform signature)
            outputDType: proposed output dtype; replaced for WrongOutputType
            scale: [scale_y_n, scale_y_d, scale_x_n, scale_x_d], mutated in place
            offset: [offset_y, offset_x], mutated in place
            border: [border_y, border_x], mutated in place
        Returns:
            tuple (scale, offset, border, outputDType) with the corruption applied
        """
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # Numerators are at indices 0 (y) and 2 (x); legal maximum is 1 << 11
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # Denominators are at indices 1 (y) and 3 (x); must be < 16 * numerator
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            # Push offset to >= 16 * scale_n (the exclusive maximum)
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            # Push offset below -scale_n (the inclusive minimum)
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            # Push border to >= scale_n (the exclusive maximum)
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            # Push border below -16 * scale_n (the inclusive minimum)
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output dtype that is illegal for this mode/input dtype
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Corrupt pooling parameters for the requested ERROR_IF.

        Returns a (stride, pad, kernel) tuple with exactly one element
        replaced by invalid values, or (None, None, None) when error_name is
        not a pooling-parameter condition.
        """
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when output_dtype is illegal for RESCALE of input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for ERROR_IF checks.

        Randomly appends a dummy name or drops entries so the list length no
        longer matches what the operator expects.
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        # Shrink every dimension (floor 1) until the element count fits
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # FIX: was missing @staticmethod, unlike every sibling method
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Corrupt SLICE start/size parameters for the requested ERROR_IF."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output dtypes for input_dtype.

        Raises:
            AssertionError: if input_dtype has no defined invalid set.
        """
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # FIX: was ``assert True`` which never fires and then crashed
            # with UnboundLocalError on the return below
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
294
295
296class TosaErrorValidator:
297 @staticmethod
298 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
299 """Check ERROR_IF statements are caught and set the expected result.
300
301 Args:
302 serializer: the serializer to set the expected result in
303 validator_fcns: a sequence of validator functions to verify the result
304 error_name: the name of the ERROR_IF condition to check for
305 kwargs: keyword arguments for the validator functions
306 Returns:
307 True if the result matches the expected result; otherwise False
308 """
309 overall_result = True
310 for val_fcn in validator_fcns:
311 val_result = val_fcn(True, **kwargs)
312 validator_name = val_result["error_name"]
313 error_result = val_result["error_result"]
314 error_reason = val_result["error_reason"]
315
316 # expect an error IFF the error_name and validator_name match
317 expected_result = error_result == (error_name == validator_name)
318 overall_result &= expected_result
319
320 if expected_result and error_result:
321 serializer.setExpectedReturnCode(2, True, desc=error_reason)
322 elif error_result: # and not expected_result
323 print(
324 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
325 f" Expected: {error_name}, Got: {validator_name}"
326 )
327 elif not expected_result: # and not error_result
328 print(
329 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
330 f" Expected: {error_name}"
331 )
332
333 if not expected_result:
334 for k, v in sorted(kwargs.items()):
335 if k != "op":
336 if k.endswith("dtype"):
337 v = valueToName(DType, v)
338 print(f" {k} = {v}")
339
340 return overall_result
341
342 @staticmethod
343 def evWrongInputType(check=False, **kwargs):
344 error_result = False
345
346 # Find the unsupported input data types
347 op = kwargs["op"]
348 input_dtypes = op["types"]
349 allowed_input_dtypes = {
350 t[0] if isinstance(t, list) else t for t in input_dtypes
351 }
352 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
353
354 if op["op"] == Op.CLAMP:
355 wrong_input_dtypes.remove(DType.INT48)
356
357 if check:
358 input_dtype = kwargs["input_dtype"]
359 if input_dtype not in allowed_input_dtypes:
360 error_result = True
361
362 info_dict = {
363 "error_name": ErrorIf.WrongInputType,
364 "error_result": error_result,
365 "error_reason": "Input data type not supported for this operator",
366 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
367 }
368 return info_dict
369
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Build the ERROR_IF info dict for an unsupported output data type.

        Args:
            check: when True, inspect kwargs (input_dtype, output_dtype, op,
                plus mode for RESIZE) and flag an illegal output dtype for
                this operator / input-dtype combination.
            kwargs: operator parameters supplied by the test generator.
        Returns:
            info dict: error name, result flag, reason, and parameter
            requirements (no restrictions for this condition).
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            # RESIZE: legal output dtype depends on both mode and input dtype
            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            # RESCALE: delegate to the shared input/output dtype table
            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            # Accumulating ops widen: INT8->INT32, INT16->INT48, FLOAT->FLOAT
            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            # ARGMAX always produces INT32 indices
            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            # MUL: integer inputs accumulate to INT32; float stays float
            elif op["op"] == Op.MUL:
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            # TABLE: INT8 lookup -> INT8, INT16 lookup -> INT32
            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            # Comparison ops always produce BOOL
            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            # CAST: each input dtype has an explicit set of legal targets
            # (casting to the same dtype is not a valid CAST)
            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            # Convolutions widen like FULLY_CONNECTED/MATMUL
            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            # Default: output dtype must match input dtype
            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
500
501 @staticmethod
502 def evWrongRank(check=False, **kwargs):
503 all_ranks = (1, 2, 3, 4, 5)
504
505 # Make a list of incorrect ranks
506 assert "op" in kwargs
507 op = kwargs["op"]
508 rmin, rmax = op["rank"]
509 rank_range = range(rmin, rmax + 1)
510 incorrect_ranks = list(set(all_ranks) - set(rank_range))
511 # Remove small incorrect ranks to avoid index errors
512 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
513 # Set minimum incorrect rank to 3 to avoid index error
514 if op["op"] in [Op.RESIZE]:
515 incorrect_ranks = [3, 5]
516 elif op["op"] in [Op.TRANSPOSE]:
517 incorrect_ranks = [7, 8]
518 elif op["op"] in [Op.CONV3D]:
519 incorrect_ranks = [6, 7]
520
521 error_name = ErrorIf.WrongRank
522 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
523 error_result = False
524 error_reason = "Rank not supported for this operator"
525
526 if check:
527 input_shape = kwargs["input_shape"]
528
529 if (
530 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
531 and len(input_shape) != 4
532 ):
533 error_result = True
534 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
535 error_result = True
536 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
537 error_result = True
538 else:
539 if len(input_shape) not in rank_range:
540 error_result = True
541
542 info_dict = {
543 "error_name": error_name,
544 "error_result": error_result,
545 "error_reason": error_reason,
546 "param_reqs": param_reqs,
547 }
548 return info_dict
549
550 @staticmethod
551 def evWrongInputList(check=False, **kwargs):
552 error_name = ErrorIf.WrongInputList
553 param_reqs = {"rank": None, "dtype": None, "shape": None}
554 error_result = False
555 error_reason = "Op input list does not match expected input"
556
557 if check:
558 op = kwargs["op"]
559 input_list = kwargs["input_list"]
560 num_operands = kwargs["num_operands"]
561 if op["op"] in [Op.SCATTER, Op.GATHER]:
562 # SCATTER/GATHER add an indices input tensor in their build functions
563 num_operands += 1
564 if len(input_list) != num_operands:
565 error_result = True
566
567 info_dict = {
568 "error_name": error_name,
569 "error_result": error_result,
570 "error_reason": error_reason,
571 "param_reqs": param_reqs,
572 }
573 return info_dict
574
575 @staticmethod
576 def evWrongOutputList(check=False, **kwargs):
577 error_name = ErrorIf.WrongOutputList
578 param_reqs = {"rank": None, "dtype": None, "shape": None}
579 error_result = False
580 error_reason = "Op output list does not match expected output"
581
582 if check:
583 output_list = kwargs["output_list"]
584 # Note this will be incorrect if an operator returns more than one output
585 if len(output_list) != 1:
586 error_result = True
587
588 info_dict = {
589 "error_name": error_name,
590 "error_result": error_result,
591 "error_reason": error_reason,
592 "param_reqs": param_reqs,
593 }
594 return info_dict
595
596 @staticmethod
597 def evMaxDimExceeded(check=False, **kwargs):
598 error_name = ErrorIf.MaxDimExceeded
599 param_reqs = {
600 "rank": [4, 4],
601 "dtype": [DType.INT8],
602 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
603 }
604 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100605 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100606
607 if check:
608 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100609 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100610 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100611 (input_shape[1] >= MAX_RESIZE_DIMENSION)
612 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
613 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
614 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100615 ):
616 error_result = True
617
618 info_dict = {
619 "error_name": error_name,
620 "error_result": error_result,
621 "error_reason": error_reason,
622 "param_reqs": param_reqs,
623 }
624 return info_dict
625
626 @staticmethod
627 def evBatchMismatch(check=False, **kwargs):
628 error_name = ErrorIf.BatchMismatch
629 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
630 error_result = False
631 error_reason = "Input batch size not equal to output batch size"
632
633 assert "op" in kwargs
634 op = kwargs["op"]
635 rmin, rmax = op["rank"]
636 rank_range = range(rmin, rmax + 1)
637
638 if check:
639 input_shape = kwargs["input_shape"]
640 output_shape = kwargs[
641 "result_tensor"
642 ].shape # Note this is just (N, OH, OW, C)
643
644 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
645 error_result = True
646
647 info_dict = {
648 "error_name": error_name,
649 "error_result": error_result,
650 "error_reason": error_reason,
651 "param_reqs": param_reqs,
652 }
653 return info_dict
654
655 @staticmethod
656 def evChannelMismatch(check=False, **kwargs):
657 error_name = ErrorIf.ChannelMismatch
658 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
659 error_result = False
660 error_reason = "Input channel size not equal to output channel size"
661
662 assert "op" in kwargs
663 op = kwargs["op"]
664 rmin, rmax = op["rank"]
665 rank_range = range(rmin, rmax + 1)
666
667 if check:
668 input_shape = kwargs["input_shape"]
669 output_shape = kwargs[
670 "result_tensor"
671 ].shape # Note this is just (N, OH, OW, C)
672 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
673 error_result = True
674
675 info_dict = {
676 "error_name": error_name,
677 "error_result": error_result,
678 "error_reason": error_reason,
679 "param_reqs": param_reqs,
680 }
681 return info_dict
682
683 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100684 def evScaleSmallerEqualZero(check=False, **kwargs):
685 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100686 param_reqs = {"rank": None, "dtype": None, "shape": None}
687 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100688 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100689
690 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100691 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100692
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100693 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100694 error_result = True
695
696 info_dict = {
697 "error_name": error_name,
698 "error_result": error_result,
699 "error_reason": error_reason,
700 "param_reqs": param_reqs,
701 }
702 return info_dict
703
704 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100705 def evScaleNLargerMax(check=False, **kwargs):
706 error_name = ErrorIf.ScaleNLargerMax
707 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100708 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100709 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100710
711 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100712 scale = kwargs["scale"]
713
714 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
715 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100716
717 info_dict = {
718 "error_name": error_name,
719 "error_result": error_result,
720 "error_reason": error_reason,
721 "param_reqs": param_reqs,
722 }
723 return info_dict
724
725 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100726 def evScaleDLargerMax(check=False, **kwargs):
727 error_name = ErrorIf.ScaleDLargerMax
728 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100729 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100730 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100731
732 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100733 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100734
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100735 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
736 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100737 ):
738 error_result = True
739
740 info_dict = {
741 "error_name": error_name,
742 "error_result": error_result,
743 "error_reason": error_reason,
744 "param_reqs": param_reqs,
745 }
746 return info_dict
747
748 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100749 def evOffsetSmallerMin(check=False, **kwargs):
750 error_name = ErrorIf.OffsetSmallerMin
751 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100752 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100753 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100754
755 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100756 scale = kwargs["scale"]
757 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100758
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100759 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100760 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100761 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100762 error_result = True
763
764 info_dict = {
765 "error_name": error_name,
766 "error_result": error_result,
767 "error_reason": error_reason,
768 "param_reqs": param_reqs,
769 }
770 return info_dict
771
772 @staticmethod
773 def evOffsetLargerEqualMax(check=False, **kwargs):
774 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100775 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100776 error_result = False
777 error_reason = "Offset value larger than or equal to maximum value"
778
779 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100780 scale = kwargs["scale"]
781 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100782
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100783 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
784 error_result = True
785 elif (
786 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
787 ):
788 error_result = True
789
790 info_dict = {
791 "error_name": error_name,
792 "error_result": error_result,
793 "error_reason": error_reason,
794 "param_reqs": param_reqs,
795 }
796 return info_dict
797
798 @staticmethod
799 def evBorderSmallerMin(check=False, **kwargs):
800 error_name = ErrorIf.BorderSmallerMin
801 param_reqs = {"rank": None, "dtype": None, "shape": None}
802 error_result = False
803 error_reason = "Border value smaller than minimum value"
804
805 if check:
806 scale = kwargs["scale"]
807 border = kwargs["border"]
808
809 if (
810 scale[0] > 0
811 and scale[0] <= (1 << 11)
812 and (border[0] < (-16 * scale[0]))
813 ):
814 error_result = True
815 elif (
816 scale[2] > 0
817 and scale[2] <= (1 << 11)
818 and (border[1] < (-16 * scale[2]))
819 ):
820 error_result = True
821
822 info_dict = {
823 "error_name": error_name,
824 "error_result": error_result,
825 "error_reason": error_reason,
826 "param_reqs": param_reqs,
827 }
828 return info_dict
829
830 @staticmethod
831 def evBorderLargerEqualMax(check=False, **kwargs):
832 error_name = ErrorIf.BorderLargerEqualMax
833 param_reqs = {"rank": None, "dtype": None, "shape": None}
834 error_result = False
835 error_reason = "Border value larger than or equal to maximum value"
836
837 if check:
838 scale = kwargs["scale"]
839 border = kwargs["border"]
840
841 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
842 error_result = True
843 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
844 error_result = True
845
846 info_dict = {
847 "error_name": error_name,
848 "error_result": error_result,
849 "error_reason": error_reason,
850 "param_reqs": param_reqs,
851 }
852 return info_dict
853
854 @staticmethod
855 def checkResizeParams(scale, offset, border):
856 return (
857 min(scale) > 0
858 and max(scale[0], scale[2]) <= (1 << 11)
859 and scale[1] < 16 * scale[0]
860 and scale[3] < 16 * scale[2]
861 and offset[0] >= -scale[0]
862 and offset[1] >= -scale[2]
863 and offset[0] < 16 * scale[0]
864 and offset[1] < 16 * scale[2]
865 and border[0] >= -16 * scale[0]
866 and border[1] >= -16 * scale[2]
867 and border[0] < scale[0]
868 and border[1] < scale[2]
869 )
870
871 @staticmethod
872 def evResizeOutputShapeMismatch(check=False, **kwargs):
873 error_name = ErrorIf.ResizeOutputShapeMismatch
874 param_reqs = {"rank": None, "dtype": None, "shape": None}
875 error_result = False
876 error_reason = (
877 "Mismatch between output shape provided and expected output shape"
878 )
879
880 if check:
881 input_shape = kwargs["input_shape"]
882 output_shape = kwargs["output_shape"]
883 scale = kwargs["scale"]
884 offset = kwargs["offset"]
885 border = kwargs["border"]
886
887 # Ensure parameters are valid
888 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
889
890 if (
891 params_valid
892 and max(output_shape) < MAX_RESIZE_DIMENSION
893 and max(input_shape) < MAX_RESIZE_DIMENSION
894 ):
895 output_y = (
896 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
897 ) // scale[1] + 1
898 output_x = (
899 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
900 ) // scale[3] + 1
901
902 if [output_y, output_x] != output_shape[1:-1]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100903 error_result = True
904
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100905 info_dict = {
906 "error_name": error_name,
907 "error_result": error_result,
908 "error_reason": error_reason,
909 "param_reqs": param_reqs,
910 }
911 return info_dict
912
913 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100914 def evResizeOutputShapeNonInteger(check=False, **kwargs):
915 error_name = ErrorIf.ResizeOutputShapeNonInteger
916 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100917 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100918 error_reason = "Parameters do not yield exact integer output dimensions"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100919
920 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100921 input_shape = kwargs["input_shape"]
922 scale = kwargs["scale"]
923 offset = kwargs["offset"]
924 border = kwargs["border"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100925
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100926 # Ensure parameters are valid
927 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100928
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100929 if params_valid:
930 remainder_y = (
931 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
932 ) % scale[1]
933 remainder_x = (
934 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
935 ) % scale[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100936
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100937 if max(remainder_y, remainder_x) > 0:
938 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100939
940 info_dict = {
941 "error_name": error_name,
942 "error_result": error_result,
943 "error_reason": error_reason,
944 "param_reqs": param_reqs,
945 }
946 return info_dict
947
948 @staticmethod
949 def evRankMismatch(check=False, **kwargs):
950 error_name = ErrorIf.RankMismatch
951 param_reqs = {"rank": None, "dtype": None, "shape": None}
952 error_result = False
953 error_reason = "Input Rank does not match output rank"
954
955 if check:
956 input1_shape = kwargs["input1"].shape
957 input2_shape = kwargs["input2"].shape
958 # In case of SELECT op
959 input3_shape = (
960 kwargs["input3"].shape if "input3" in kwargs else input2_shape
961 )
962 output_shape = kwargs["result_tensor"].shape
963 if (
964 (len(input1_shape) != len(output_shape))
965 or (len(input2_shape) != len(output_shape))
966 or (len(input3_shape) != len(output_shape))
967 ):
968 error_result = True
969
970 info_dict = {
971 "error_name": error_name,
972 "error_result": error_result,
973 "error_reason": error_reason,
974 "param_reqs": param_reqs,
975 }
976 return info_dict
977
978 @staticmethod
979 def evDimensionMismatch(check=False, **kwargs):
980 error_name = ErrorIf.DimensionMismatch
981 param_reqs = {"rank": None, "dtype": None, "shape": None}
982 error_result = False
983 error_reason = "Input Dimensions do not match output"
984
985 if check:
986 input1_shape = kwargs["input1"].shape
987 input2_shape = kwargs["input2"].shape
988 # In case of SELECT op
989 input3_shape = (
990 kwargs["input3"].shape if "input3" in kwargs else input2_shape
991 )
992 output_shape = kwargs["result_tensor"].shape
993 for i in range(
994 min(len(input1_shape), len(input2_shape), len(input3_shape))
995 ):
996 if (
997 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
998 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
999 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
1000 ):
1001 error_result = True
1002
1003 info_dict = {
1004 "error_name": error_name,
1005 "error_result": error_result,
1006 "error_reason": error_reason,
1007 "param_reqs": param_reqs,
1008 }
1009 return info_dict
1010
1011 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001012 def _getZeroPoint(qinfo, index):
1013 """Return zero point value from quantization info.
1014
1015 Generally input_zp is index 0, output_zp is index 1
1016 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001017 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001018
1019 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001020 def evInputZeroPointNotZero(check=False, **kwargs):
1021 op = kwargs["op"]
1022 error_result = False
1023
1024 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001025 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001026
1027 # This does not apply to quantizable types
1028 inputDtypes = [
1029 dtype
1030 for dtype in op["types"]
1031 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1032 or (not isinstance(dtype, list) and dtype not in qTypes)
1033 ]
1034
1035 if check:
1036 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001037 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001038 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001039 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001040 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001041 (kwargs["input_dtype"], input_zero_point),
1042 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001043 ):
1044 if dtype not in qTypes and zp != 0:
1045 error_result = True
1046 break
1047 else:
1048 error_result = input_dtype not in qTypes and input_zero_point != 0
1049
1050 info_dict = {
1051 "error_name": ErrorIf.InputZeroPointNotZero,
1052 "error_result": error_result,
1053 "error_reason": "Input DType not INT8 and zero point not 0",
1054 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1055 }
1056 return info_dict
1057
1058 @staticmethod
1059 def evWeightZeroPointNotZero(check=False, **kwargs):
1060 op = kwargs["op"]
1061
1062 # exclude inputs with INT8 weights
1063 inputDtypes = [
1064 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1065 ]
1066
1067 error_name = ErrorIf.WeightZeroPointNotZero
1068 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1069 error_result = False
1070 error_reason = "Weight DType not INT8 and zero point not 0"
1071
1072 if check:
1073 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001074 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001075 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1076 error_result = True
1077
1078 info_dict = {
1079 "error_name": error_name,
1080 "error_result": error_result,
1081 "error_reason": error_reason,
1082 "param_reqs": param_reqs,
1083 }
1084 return info_dict
1085
1086 @staticmethod
1087 def evOutputZeroPointNotZero(check=False, **kwargs):
1088 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001089 inputDtypes = [
1090 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1091 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092
1093 error_name = ErrorIf.OutputZeroPointNotZero
1094 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1095 error_result = False
1096 error_reason = "Output DType not INT8 and zero point not 0"
1097
1098 if check:
1099 input_dtype = kwargs["input_dtype"]
1100 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001101 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001102 if op["op"] == Op.AVG_POOL2D:
1103 if input_dtype != DType.INT8 and output_zero_point != 0:
1104 error_result = True
1105 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001106 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1107 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001108 ):
1109 error_result = True
1110
1111 info_dict = {
1112 "error_name": error_name,
1113 "error_result": error_result,
1114 "error_reason": error_reason,
1115 "param_reqs": param_reqs,
1116 }
1117 return info_dict
1118
1119 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001120 def evU16InputZeroPointNotValid(check=False, **kwargs):
1121 error_name = ErrorIf.U16InputZeroPointNotValid
1122 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1123 error_result = False
1124 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1125
1126 if check:
1127 input_dtype = kwargs["input_dtype"]
1128 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1129 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1130 0,
1131 32768,
1132 ]
1133
1134 info_dict = {
1135 "error_name": error_name,
1136 "error_result": error_result,
1137 "error_reason": error_reason,
1138 "param_reqs": param_reqs,
1139 }
1140 return info_dict
1141
1142 @staticmethod
1143 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1144 error_name = ErrorIf.U16OutputZeroPointNotValid
1145 param_reqs = {"rank": None, "dtype": None, "shape": None}
1146 error_result = False
1147 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1148
1149 if check:
1150 output_dtype = kwargs["output_dtype"]
1151 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1152
1153 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1154 0,
1155 32768,
1156 ]
1157
1158 info_dict = {
1159 "error_name": error_name,
1160 "error_result": error_result,
1161 "error_reason": error_reason,
1162 "param_reqs": param_reqs,
1163 }
1164 return info_dict
1165
1166 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001167 def evAxisSmallerZero(check=False, **kwargs):
1168 error_name = ErrorIf.AxisSmallerZero
1169 param_reqs = {"rank": None, "dtype": None, "shape": None}
1170 error_result = False
1171 error_reason = "Axis smaller than zero"
1172
1173 if check:
1174 axis = kwargs["axis"]
1175 if axis < 0:
1176 error_result = True
1177
1178 info_dict = {
1179 "error_name": error_name,
1180 "error_result": error_result,
1181 "error_reason": error_reason,
1182 "param_reqs": param_reqs,
1183 }
1184 return info_dict
1185
1186 @staticmethod
1187 def evAxisLargerRank(check=False, **kwargs):
1188 error_name = ErrorIf.AxisLargerRank
1189 param_reqs = {"rank": None, "dtype": None, "shape": None}
1190 error_result = False
1191 error_reason = "Axis larger than rank"
1192
1193 if check:
1194 axis = kwargs["axis"]
1195 shape = kwargs["input_shape"]
1196 if axis > len(shape):
1197 error_result = True
1198
1199 info_dict = {
1200 "error_name": error_name,
1201 "error_result": error_result,
1202 "error_reason": error_reason,
1203 "param_reqs": param_reqs,
1204 }
1205 return info_dict
1206
1207 @staticmethod
1208 def evShapeOfAxisNotOne(check=False, **kwargs):
1209 error_name = ErrorIf.ShapeOfAxisNotOne
1210 param_reqs = {"rank": None, "dtype": None, "shape": None}
1211 error_result = False
1212 error_reason = "shape[axis] is not equal to 1"
1213
1214 if check:
1215 axis = kwargs["axis"]
1216 shape = kwargs["output_shape"]
1217 if (0 <= axis < len(shape)) and shape[axis] != 1:
1218 error_result = True
1219
1220 info_dict = {
1221 "error_name": error_name,
1222 "error_result": error_result,
1223 "error_reason": error_reason,
1224 "param_reqs": param_reqs,
1225 }
1226 return info_dict
1227
1228 @staticmethod
1229 def evPadSmallerZero(check=False, **kwargs):
1230 error_name = ErrorIf.PadSmallerZero
1231 param_reqs = {"rank": None, "dtype": None, "shape": None}
1232 error_result = False
1233 error_reason = "At least one pad is smaller than zero"
1234
1235 if check:
1236 op = kwargs["op"]
1237 pad = kwargs["pad"]
1238 if op["op"] == Op.PAD:
1239 for padding in pad:
1240 if min(padding) < 0:
1241 error_result = True
1242 else:
1243 if min(pad) < 0:
1244 error_result = True
1245
1246 info_dict = {
1247 "error_name": error_name,
1248 "error_result": error_result,
1249 "error_reason": error_reason,
1250 "param_reqs": param_reqs,
1251 }
1252 return info_dict
1253
1254 @staticmethod
1255 def evPadLargerEqualKernel(check=False, **kwargs):
1256 error_name = ErrorIf.PadLargerEqualKernel
1257 param_reqs = {"rank": None, "dtype": None, "shape": None}
1258 error_result = False
1259 error_reason = "At least one pad is larger than kernel dimension"
1260
1261 if check:
1262 pad = kwargs["pad"]
1263 kernel = kwargs["kernel"]
1264 if min(pad) > 0 and min(kernel) > 1:
1265 if (
1266 pad[0] >= kernel[0]
1267 or pad[1] >= kernel[0]
1268 or pad[2] >= kernel[1]
1269 or pad[3] >= kernel[1]
1270 ):
1271 error_result = True
1272
1273 info_dict = {
1274 "error_name": error_name,
1275 "error_result": error_result,
1276 "error_reason": error_reason,
1277 "param_reqs": param_reqs,
1278 }
1279 return info_dict
1280
1281 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001282 def checkPoolingParams(kernel, stride, pad):
1283 return (
1284 min(kernel) >= 1
1285 and min(stride) >= 1
1286 and min(pad) >= 0
1287 and not (
1288 pad[0] >= kernel[0]
1289 or pad[1] >= kernel[0]
1290 or pad[2] >= kernel[1]
1291 or pad[3] >= kernel[1]
1292 )
1293 )
1294
1295 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001296 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1297 error_name = ErrorIf.PoolingOutputShapeMismatch
1298 param_reqs = {"rank": None, "dtype": None, "shape": None}
1299 error_result = False
1300 error_reason = (
1301 "Mismatch between output shape provided and expected output shape"
1302 )
1303
1304 if check:
1305 pad = kwargs["pad"]
1306 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1307
1308 kernel = kwargs["kernel"]
1309 kernel_y, kernel_x = kernel[0], kernel[1]
1310
1311 input_shape = kwargs["input_shape"]
1312 IH, IW = input_shape[1], input_shape[2]
1313
1314 output_shape = kwargs["output_shape"]
1315 OH, OW = output_shape[1], output_shape[2]
1316
1317 stride = kwargs["stride"]
1318 stride_y, stride_x = stride[0], stride[1]
1319
1320 # calculate correct height, width dimensions
1321 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001322 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1323 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001324
1325 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001326 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001327
1328 if params_valid and (OH != y_correct or OW != x_correct):
1329 error_result = True
1330
1331 info_dict = {
1332 "error_name": error_name,
1333 "error_result": error_result,
1334 "error_reason": error_reason,
1335 "param_reqs": param_reqs,
1336 }
1337 return info_dict
1338
1339 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001340 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1341 error_name = ErrorIf.PoolingOutputShapeNonInteger
1342 param_reqs = {"rank": None, "dtype": None, "shape": None}
1343 error_result = False
1344 error_reason = "Parameters do not yield exact integer output dimensions"
1345
1346 if check:
1347 pad = kwargs["pad"]
1348 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1349
1350 kernel = kwargs["kernel"]
1351 kernel_y, kernel_x = kernel[0], kernel[1]
1352
1353 input_shape = kwargs["input_shape"]
1354 IH, IW = input_shape[1], input_shape[2]
1355
1356 stride = kwargs["stride"]
1357 stride_y, stride_x = stride[0], stride[1]
1358
1359 # calculate remainder of height, width dimensions
1360 if stride_x != 0 and stride_y != 0:
1361 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1362 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1363
1364 # ensure parameters are valid
1365 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1366 if params_valid and (y_remainder != 0 or x_remainder != 0):
1367 error_result = True
1368
1369 info_dict = {
1370 "error_name": error_name,
1371 "error_result": error_result,
1372 "error_reason": error_reason,
1373 "param_reqs": param_reqs,
1374 }
1375 return info_dict
1376
1377 @staticmethod
1378 def checkConvParams(weight_shape, stride, pad, dilation):
1379 return (
1380 # Check kernel sizes
1381 min(weight_shape[1:-1]) >= 1
1382 and min(stride) >= 1
1383 and min(pad) >= 0
1384 and (dilation is None or min(dilation) >= 1)
1385 )
1386
1387 @staticmethod
1388 def evConvOutputShapeMismatch(check=False, **kwargs):
1389 error_name = ErrorIf.ConvOutputShapeMismatch
1390 param_reqs = {"rank": None, "dtype": None, "shape": None}
1391 error_result = False
1392 error_reason = (
1393 "Mismatch between output shape provided and expected output shape"
1394 )
1395
1396 if check:
1397 op = kwargs["op"]
1398 pad = kwargs["pad"]
1399 weight_shape = kwargs["weight_shape"]
1400 input_shape = kwargs["input_shape"]
1401 output_shape = kwargs["output_shape"]
1402 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1403 stride = kwargs["stride"]
1404
1405 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1406
1407 # calculate correct dimensions
1408 dims_correct = []
1409 if min(stride) > 0:
1410 for index in range(len(stride)):
1411 pad_offset = index * 2
1412 if op["op"] == Op.TRANSPOSE_CONV2D:
1413 dims_correct.append(
1414 (input_shape[index + 1] - 1) * stride[index]
1415 - pad[pad_offset]
1416 - pad[pad_offset + 1]
1417 + weight_shape[index + kernel_offset]
1418 )
1419 else:
1420 dims_correct.append(
1421 (
1422 input_shape[index + 1]
1423 - 1
1424 + pad[pad_offset]
1425 + pad[pad_offset + 1]
1426 - (weight_shape[index + kernel_offset] - 1)
1427 * dilation[index]
1428 )
1429 // stride[index]
1430 + 1
1431 )
1432
1433 # ensure parameters are valid
1434 params_valid = TosaErrorValidator.checkConvParams(
1435 weight_shape, stride, pad, dilation
1436 )
1437
1438 if params_valid and output_shape[1:-1] != dims_correct:
1439 error_result = True
1440
1441 info_dict = {
1442 "error_name": error_name,
1443 "error_result": error_result,
1444 "error_reason": error_reason,
1445 "param_reqs": param_reqs,
1446 }
1447 return info_dict
1448
1449 @staticmethod
1450 def evConvOutputShapeNonInteger(check=False, **kwargs):
1451 error_name = ErrorIf.ConvOutputShapeNonInteger
1452 param_reqs = {"rank": None, "dtype": None, "shape": None}
1453 error_result = False
1454 error_reason = "Parameters do not yield exact integer output dimensions"
1455
1456 if check:
1457 op = kwargs["op"]
1458 pad = kwargs["pad"]
1459 weight_shape = kwargs["weight_shape"]
1460 input_shape = kwargs["input_shape"]
1461 dilation = kwargs["dilation"]
1462 stride = kwargs["stride"]
1463
1464 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1465
1466 # calculate correct height, width dimensions
1467 remainders = []
1468 if min(stride) > 0:
1469 for index in range(len(stride)):
1470 pad_offset = index * 2
1471 remainders.append(
1472 (
1473 input_shape[index + 1]
1474 - 1
1475 + pad[pad_offset]
1476 + pad[pad_offset + 1]
1477 - (weight_shape[index + kernel_offset] - 1)
1478 * dilation[index]
1479 )
1480 % stride[index]
1481 )
1482
1483 # ensure parameters are valid
1484 params_valid = TosaErrorValidator.checkConvParams(
1485 weight_shape, stride, pad, dilation
1486 )
1487 if params_valid and max(remainders) > 0:
1488 error_result = True
1489
1490 info_dict = {
1491 "error_name": error_name,
1492 "error_result": error_result,
1493 "error_reason": error_reason,
1494 "param_reqs": param_reqs,
1495 }
1496 return info_dict
1497
1498 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001499 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1500 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1501 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1502 error_result = False
1503 error_reason = (
1504 "Mismatch between output shape provided and expected output shape"
1505 )
1506
1507 if check:
1508 output_shape = kwargs["output_shape"]
1509 input_shape = kwargs["input_shape"]
1510 axis = kwargs["axis"]
1511
1512 dimension_match = True
1513 axis_shift = 0
1514
1515 # Check that rank is correct before trying to check dimensions
1516 if (len(input_shape) - 1) == len(output_shape):
1517 for i in range(len(input_shape)):
1518 if i == axis:
1519 axis_shift = 1
1520 continue
1521 if input_shape[i] != output_shape[i - axis_shift]:
1522 dimension_match = False
1523
1524 if not dimension_match:
1525 error_result = True
1526
1527 info_dict = {
1528 "error_name": error_name,
1529 "error_result": error_result,
1530 "error_reason": error_reason,
1531 "param_reqs": param_reqs,
1532 }
1533 return info_dict
1534
1535 @staticmethod
1536 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1537 error_name = ErrorIf.ArgmaxOutputRankMismatch
1538 param_reqs = {"rank": None, "dtype": None, "shape": None}
1539 error_result = False
1540 error_reason = (
1541 "Mismatch between output shape provided and expected output shape"
1542 )
1543
1544 if check:
1545 output_shape = kwargs["output_shape"]
1546 input_shape = kwargs["input_shape"]
1547 axis = kwargs["axis"]
1548 valid_params = axis >= 0 and axis < len(input_shape)
1549
1550 if valid_params and (len(input_shape) - 1) != len(output_shape):
1551 error_result = True
1552
1553 info_dict = {
1554 "error_name": error_name,
1555 "error_result": error_result,
1556 "error_reason": error_reason,
1557 "param_reqs": param_reqs,
1558 }
1559 return info_dict
1560
1561 @staticmethod
1562 def evKernelSmallerOne(check=False, **kwargs):
1563 error_name = ErrorIf.KernelSmallerOne
1564 param_reqs = {"rank": None, "dtype": None, "shape": None}
1565 error_result = False
1566 error_reason = "At least one kernel dimension is smaller than zero"
1567
1568 if check:
1569 kernel = kwargs["kernel"]
1570 if min(kernel) < 1:
1571 error_result = True
1572
1573 info_dict = {
1574 "error_name": error_name,
1575 "error_result": error_result,
1576 "error_reason": error_reason,
1577 "param_reqs": param_reqs,
1578 }
1579 return info_dict
1580
1581 @staticmethod
1582 def evStrideSmallerOne(check=False, **kwargs):
1583 error_name = ErrorIf.StrideSmallerOne
1584 param_reqs = {"rank": None, "dtype": None, "shape": None}
1585 error_result = False
1586 error_reason = "At least one stride dimension is smaller than zero"
1587
1588 if check:
1589 stride = kwargs["stride"]
1590 if min(stride) < 1:
1591 error_result = True
1592
1593 info_dict = {
1594 "error_name": error_name,
1595 "error_result": error_result,
1596 "error_reason": error_reason,
1597 "param_reqs": param_reqs,
1598 }
1599 return info_dict
1600
1601 @staticmethod
1602 def evDilationSmallerOne(check=False, **kwargs):
1603 error_result = check and min(kwargs["dilation"]) < 1
1604 return {
1605 "error_name": ErrorIf.DilationSmallerOne,
1606 "error_reason": "At least one dilation is smaller than one",
1607 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1608 "error_result": error_result,
1609 }
1610
1611 @staticmethod
1612 def evScaleTrue(check=False, **kwargs):
1613 error_name = ErrorIf.ScaleTrue
1614 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1615 error_result = False
1616 error_reason = "Scale set to true but input type is INT48"
1617
1618 if check:
1619 input_dtype = kwargs["input_dtype"]
1620 scale32 = kwargs["scale32"]
1621 if scale32 and input_dtype == DType.INT48:
1622 error_result = True
1623
1624 info_dict = {
1625 "error_name": error_name,
1626 "error_result": error_result,
1627 "error_reason": error_reason,
1628 "param_reqs": param_reqs,
1629 }
1630 return info_dict
1631
1632 @staticmethod
1633 def evScaleNotTrue(check=False, **kwargs):
1634 error_name = ErrorIf.ScaleNotTrue
1635 param_reqs = {"rank": None, "dtype": None, "shape": None}
1636 error_result = False
1637 error_reason = "Scale set to false but double round set to true"
1638
1639 if check:
1640 scale32 = kwargs["scale32"]
1641 double_round = kwargs["double_round"]
1642 if not scale32 and double_round:
1643 error_result = True
1644
1645 info_dict = {
1646 "error_name": error_name,
1647 "error_result": error_result,
1648 "error_reason": error_reason,
1649 "param_reqs": param_reqs,
1650 }
1651 return info_dict
1652
1653 @staticmethod
1654 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1655 error_name = ErrorIf.TensorSizeInputOutputMismatch
1656 param_reqs = {"rank": None, "dtype": None, "shape": None}
1657 error_result = False
1658 error_reason = "Input tensor size does not match output tensor size"
1659
1660 if check:
1661 input_shape = kwargs["input_shape"]
1662 output_shape = kwargs["output_shape"]
1663 input_size = np.prod(input_shape)
1664 output_size = np.prod(output_shape)
1665 if input_size != output_size:
1666 error_result = True
1667
1668 info_dict = {
1669 "error_name": error_name,
1670 "error_result": error_result,
1671 "error_reason": error_reason,
1672 "param_reqs": param_reqs,
1673 }
1674 return info_dict
1675
1676 @staticmethod
1677 def evStartSmallerZero(check=False, **kwargs):
1678 error_name = ErrorIf.StartSmallerZero
1679 param_reqs = {"rank": None, "dtype": None, "shape": None}
1680 error_result = False
1681 error_reason = "Starting point smaller than zero"
1682
1683 if check:
1684 input_shape = kwargs["input_shape"]
1685 start = kwargs["start"]
1686 rank = len(input_shape)
1687 if len(start) == rank:
1688 for index in range(rank):
1689 if start[index] < 0:
1690 error_result = True
1691
1692 info_dict = {
1693 "error_name": error_name,
1694 "error_result": error_result,
1695 "error_reason": error_reason,
1696 "param_reqs": param_reqs,
1697 }
1698 return info_dict
1699
1700 @staticmethod
1701 def evSizeSmallerEqualZero(check=False, **kwargs):
1702 error_name = ErrorIf.SizeSmallerEqualZero
1703 param_reqs = {"rank": None, "dtype": None, "shape": None}
1704 error_result = False
1705 error_reason = "Size smaller than or equal to zero"
1706
1707 if check:
1708 input_shape = kwargs["input_shape"]
1709 size = kwargs["size"]
1710 rank = len(input_shape)
1711 if len(size) == rank:
1712 for index in range(rank):
1713 if size[index] <= 0:
1714 error_result = True
1715
1716 info_dict = {
1717 "error_name": error_name,
1718 "error_result": error_result,
1719 "error_reason": error_reason,
1720 "param_reqs": param_reqs,
1721 }
1722 return info_dict
1723
1724 @staticmethod
1725 def evStartSizeOutsideBounds(check=False, **kwargs):
1726 error_name = ErrorIf.StartSizeOutsideBounds
1727 param_reqs = {"rank": None, "dtype": None, "shape": None}
1728 error_result = False
1729 error_reason = "starting point plus size larger than input dimension"
1730
1731 if check:
1732 input_shape = kwargs["input_shape"]
1733 start = kwargs["start"]
1734 size = kwargs["size"]
1735 rank = len(input_shape)
1736 if len(start) == rank and len(size) == rank:
1737 for index in range(rank):
1738 if start[index] + size[index] > input_shape[index]:
1739 error_result = True
1740
1741 info_dict = {
1742 "error_name": error_name,
1743 "error_result": error_result,
1744 "error_reason": error_reason,
1745 "param_reqs": param_reqs,
1746 }
1747 return info_dict
1748
1749 @staticmethod
1750 def evSizeOutputShapeMismatch(check=False, **kwargs):
1751 error_name = ErrorIf.SizeOutputShapeMismatch
1752 param_reqs = {"rank": None, "dtype": None, "shape": None}
1753 error_result = False
1754 error_reason = "Size does not match output dimension"
1755
1756 if check:
1757 input_shape = kwargs["input_shape"]
1758 output_shape = kwargs["output_shape"]
1759 size = kwargs["size"]
1760 rank = len(input_shape)
1761 if len(size) == rank:
1762 for index in range(rank):
1763 if size[index] != output_shape[index]:
1764 error_result = True
1765
1766 info_dict = {
1767 "error_name": error_name,
1768 "error_result": error_result,
1769 "error_reason": error_reason,
1770 "param_reqs": param_reqs,
1771 }
1772 return info_dict
1773
1774 @staticmethod
1775 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1776 error_name = ErrorIf.InputSizeStartLengthMismatch
1777 param_reqs = {"rank": None, "dtype": None, "shape": None}
1778 error_result = False
1779 error_reason = "rank of input not equal to length of start or size"
1780
1781 if check:
1782 input_shape = kwargs["input_shape"]
1783 start = kwargs["start"]
1784 size = kwargs["size"]
1785 rank = len(input_shape)
1786 if rank != len(start) or rank != len(size):
1787 error_result = True
1788
1789 info_dict = {
1790 "error_name": error_name,
1791 "error_result": error_result,
1792 "error_reason": error_reason,
1793 "param_reqs": param_reqs,
1794 }
1795 return info_dict
1796
1797 @staticmethod
1798 def evIndexOutsideBounds(check=False, **kwargs):
1799 error_name = ErrorIf.IndexOutsideBounds
1800 param_reqs = {"rank": None, "dtype": None, "shape": None}
1801 error_result = False
1802 error_reason = "Index outside of allowed bounds"
1803
1804 if check:
1805 input_shape = kwargs["input_shape"]
1806 perms = kwargs["perms"]
1807 rank = len(input_shape)
1808
1809 for index in perms:
1810 if index < 0 or index > rank:
1811 error_result = True
1812
1813 info_dict = {
1814 "error_name": error_name,
1815 "error_result": error_result,
1816 "error_reason": error_reason,
1817 "param_reqs": param_reqs,
1818 }
1819 return info_dict
1820
1821 @staticmethod
1822 def evIndexUsedTwice(check=False, **kwargs):
1823 error_name = ErrorIf.IndexUsedTwice
1824 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1825 error_result = False
1826 error_reason = "Index used multiple times"
1827
1828 if check:
1829 perms = kwargs["perms"]
1830
1831 unique_indices = []
1832 for index in perms:
1833 if index in unique_indices:
1834 error_result = True
1835 else:
1836 unique_indices.append(index)
1837
1838 info_dict = {
1839 "error_name": error_name,
1840 "error_result": error_result,
1841 "error_reason": error_reason,
1842 "param_reqs": param_reqs,
1843 }
1844 return info_dict
1845
1846 @staticmethod
1847 def evMaxSmallerMin(check=False, **kwargs):
1848 error_name = ErrorIf.MaxSmallerMin
1849 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1850 error_result = False
1851 error_reason = "Max value smaller than min value"
1852
1853 if check:
1854 max_val = kwargs["max_val"]
1855 min_val = kwargs["min_val"]
1856 if max_val < min_val:
1857 error_result = True
1858
1859 info_dict = {
1860 "error_name": error_name,
1861 "error_result": error_result,
1862 "error_reason": error_reason,
1863 "param_reqs": param_reqs,
1864 }
1865 return info_dict
1866
1867 @staticmethod
1868 def evConcatInputRankMismatch(check=False, **kwargs):
1869 error_name = ErrorIf.ConcatInputRankMismatch
1870 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1871 error_result = False
1872 error_reason = "Input ranks are not identical"
1873
1874 if check:
1875 inputs = kwargs["inputs"]
1876 input_shape = kwargs["input_shape"]
1877 for input in inputs:
1878 if len(input.shape) != len(input_shape):
1879 error_result = True
1880
1881 info_dict = {
1882 "error_name": error_name,
1883 "error_result": error_result,
1884 "error_reason": error_reason,
1885 "param_reqs": param_reqs,
1886 }
1887 return info_dict
1888
1889 @staticmethod
1890 def evConcatInputDimMismatch(check=False, **kwargs):
1891 error_name = ErrorIf.ConcatInputDimMismatch
1892 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1893 error_result = False
1894 error_reason = "Input dimensions differ on too many axes"
1895
1896 if check:
1897 inputs = kwargs["inputs"]
1898 input_shape = kwargs["input_shape"]
1899 axis = kwargs["axis"]
1900
1901 # Ensure rank is valid before checking dims.
1902 valid_rank = True
1903 for input in inputs:
1904 if len(input.shape) != len(input_shape):
1905 valid_rank = False
1906
1907 if valid_rank:
1908 for input in inputs:
1909 for i, dim in enumerate(input.shape):
1910 if dim != input_shape[i] and axis != i:
1911 error_result = True
1912
1913 info_dict = {
1914 "error_name": error_name,
1915 "error_result": error_result,
1916 "error_reason": error_reason,
1917 "param_reqs": param_reqs,
1918 }
1919 return info_dict
1920
1921 @staticmethod
1922 def evConcatShapeSumMismatch(check=False, **kwargs):
1923 error_name = ErrorIf.ConcatShapeSumMismatch
1924 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1925 error_result = False
1926 error_reason = "Sum of dimensions on axis not equal to output dimension"
1927
1928 if check:
1929 inputs = kwargs["inputs"]
1930 input_shape = kwargs["input_shape"]
1931 output_shape = kwargs["output_shape"]
1932 axis = kwargs["axis"]
1933
1934 # Ensure rank is valid before checking dims.
1935 valid_params = True
1936 for input in inputs:
1937 if len(input.shape) != len(input_shape):
1938 valid_params = False
1939 if axis < 0 or axis > len(input_shape):
1940 valid_params = False
1941
1942 if valid_params:
1943 axis_dim_sum = 0
1944 for input in inputs:
1945 axis_dim_sum += input.shape[axis]
1946
1947 if axis_dim_sum != output_shape[axis]:
1948 error_result = True
1949
1950 info_dict = {
1951 "error_name": error_name,
1952 "error_result": error_result,
1953 "error_reason": error_reason,
1954 "param_reqs": param_reqs,
1955 }
1956 return info_dict
1957
1958 @staticmethod
1959 def evInputListThenGraphMismatch(check=False, **kwargs):
1960 error_name = ErrorIf.CondIfInputListThenGraphMismatch
1961 param_reqs = {"rank": None, "dtype": None, "shape": None}
1962 error_result = False
1963 error_reason = "Input list shape does not match then-graph shape"
1964
1965 if check:
1966 a = kwargs["a"]
1967 b = kwargs["b"]
1968 basicBlocks = kwargs["basicBlocks"]
1969 then_block = basicBlocks[1]
1970 then_inputs = then_block.inputs
1971 then_tens = then_block.tensors
1972 if (a.shape != then_tens[then_inputs[0]].shape) or (
1973 b.shape != then_tens[then_inputs[1]].shape
1974 ):
1975 error_result = True
1976
1977 info_dict = {
1978 "error_name": error_name,
1979 "error_result": error_result,
1980 "error_reason": error_reason,
1981 "param_reqs": param_reqs,
1982 }
1983 return info_dict
1984
1985 @staticmethod
1986 def evInputListElseGraphMismatch(check=False, **kwargs):
1987 error_name = ErrorIf.CondIfInputListElseGraphMismatch
1988 param_reqs = {"rank": None, "dtype": None, "shape": None}
1989 error_result = False
1990 error_reason = "Input list shape does not match else-graph shape"
1991
1992 if check:
1993 a = kwargs["a"]
1994 b = kwargs["b"]
1995 basicBlocks = kwargs["basicBlocks"]
1996 else_block = basicBlocks[2]
1997 else_inputs = else_block.inputs
1998 else_tens = else_block.tensors
1999 if (a.shape != else_tens[else_inputs[0]].shape) or (
2000 b.shape != else_tens[else_inputs[1]].shape
2001 ):
2002 error_result = True
2003
2004 info_dict = {
2005 "error_name": error_name,
2006 "error_result": error_result,
2007 "error_reason": error_reason,
2008 "param_reqs": param_reqs,
2009 }
2010 return info_dict
2011
2012 @staticmethod
2013 def evOutputListThenGraphMismatch(check=False, **kwargs):
2014 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2015 param_reqs = {"rank": None, "dtype": None, "shape": None}
2016 error_result = False
2017 error_reason = "Output list shape does not match then-graph shape"
2018
2019 if check:
2020 basicBlocks = kwargs["basicBlocks"]
2021 cond_block = basicBlocks[0]
2022 cond_outputs = cond_block.outputs
2023 cond_tens = cond_block.tensors
2024 then_block = basicBlocks[1]
2025 then_outputs = then_block.outputs
2026 then_tens = then_block.tensors
2027 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2028 error_result = True
2029
2030 info_dict = {
2031 "error_name": error_name,
2032 "error_result": error_result,
2033 "error_reason": error_reason,
2034 "param_reqs": param_reqs,
2035 }
2036 return info_dict
2037
2038 @staticmethod
2039 def evOutputListElseGraphMismatch(check=False, **kwargs):
2040 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2041 param_reqs = {"rank": None, "dtype": None, "shape": None}
2042 error_result = False
2043 error_reason = "Output list shape does not match else-graph shape"
2044
2045 if check:
2046 basicBlocks = kwargs["basicBlocks"]
2047 cond_block = basicBlocks[0]
2048 cond_outputs = cond_block.outputs
2049 cond_tens = cond_block.tensors
2050 else_block = basicBlocks[2]
2051 else_outputs = else_block.outputs
2052 else_tens = else_block.tensors
2053 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2054 error_result = True
2055
2056 info_dict = {
2057 "error_name": error_name,
2058 "error_result": error_result,
2059 "error_reason": error_reason,
2060 "param_reqs": param_reqs,
2061 }
2062 return info_dict
2063
2064 @staticmethod
2065 def evInputListOutputListMismatch(check=False, **kwargs):
2066 error_name = ErrorIf.InputListOutputListMismatch
2067 param_reqs = {"rank": None, "dtype": None, "shape": None}
2068 error_result = False
2069 error_reason = "Input list does not match output list"
2070
2071 if check:
2072 basicBlocks = kwargs["basicBlocks"]
2073 while_block = basicBlocks[0]
2074 while_inputs = while_block.inputs
2075 while_outputs = while_block.outputs
2076 while_tens = while_block.tensors
2077 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2078 error_result = True
2079
2080 info_dict = {
2081 "error_name": error_name,
2082 "error_result": error_result,
2083 "error_reason": error_reason,
2084 "param_reqs": param_reqs,
2085 }
2086 return info_dict
2087
2088 @staticmethod
2089 def evInputListCondGraphMismatch(check=False, **kwargs):
2090 error_name = ErrorIf.InputListCondGraphMismatch
2091 param_reqs = {"rank": None, "dtype": None, "shape": None}
2092 error_result = False
2093 error_reason = "Input list does not match cond graph"
2094
2095 if check:
2096 basicBlocks = kwargs["basicBlocks"]
2097 while_block = basicBlocks[0]
2098 while_inputs = while_block.inputs
2099 while_tens = while_block.tensors
2100 cond_block = basicBlocks[1]
2101 cond_inputs = cond_block.inputs
2102 cond_tens = cond_block.tensors
2103 if (
2104 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2105 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2106 error_result = True
2107
2108 info_dict = {
2109 "error_name": error_name,
2110 "error_result": error_result,
2111 "error_reason": error_reason,
2112 "param_reqs": param_reqs,
2113 }
2114 return info_dict
2115
2116 @staticmethod
2117 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2118 error_name = ErrorIf.InputListBodyGraphInputMismatch
2119 param_reqs = {"rank": None, "dtype": None, "shape": None}
2120 error_result = False
2121 error_reason = "Input list does not match body graph input"
2122
2123 if check:
2124 basicBlocks = kwargs["basicBlocks"]
2125 while_block = basicBlocks[0]
2126 while_inputs = while_block.inputs
2127 while_tens = while_block.tensors
2128 body_block = basicBlocks[2]
2129 body_outputs = body_block.inputs
2130 body_tens = body_block.tensors
2131 if (
2132 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2133 ) or (
2134 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2135 ):
2136 error_result = True
2137
2138 info_dict = {
2139 "error_name": error_name,
2140 "error_result": error_result,
2141 "error_reason": error_reason,
2142 "param_reqs": param_reqs,
2143 }
2144 return info_dict
2145
2146 @staticmethod
2147 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2148 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2149 param_reqs = {"rank": None, "dtype": None, "shape": None}
2150 error_result = False
2151 error_reason = "Input list does not match body graph output"
2152
2153 if check:
2154 basicBlocks = kwargs["basicBlocks"]
2155 while_block = basicBlocks[0]
2156 while_inputs = while_block.inputs
2157 while_tens = while_block.tensors
2158 body_block = basicBlocks[2]
2159 body_outputs = body_block.outputs
2160 body_tens = body_block.tensors
2161 if (
2162 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2163 ) or (
2164 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2165 ):
2166 error_result = True
2167 info_dict = {
2168 "error_name": error_name,
2169 "error_result": error_result,
2170 "error_reason": error_reason,
2171 "param_reqs": param_reqs,
2172 }
2173 return info_dict
2174
2175 @staticmethod
2176 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2177 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2178 param_reqs = {"rank": None, "dtype": None, "shape": None}
2179 error_result = False
2180 error_reason = "Cond graph output is not a match list of booleans"
2181
2182 if check:
2183 basicBlocks = kwargs["basicBlocks"]
2184 cond_block = basicBlocks[1]
2185 cond_outputs = cond_block.outputs
2186 cond_tens = cond_block.tensors
2187 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2188 error_result = True
2189
2190 info_dict = {
2191 "error_name": error_name,
2192 "error_result": error_result,
2193 "error_reason": error_reason,
2194 "param_reqs": param_reqs,
2195 }
2196 return info_dict
2197
2198
class TosaInvalidValidator:
    """Checks that spot invalid (non ERROR_IF) test parameter combinations,
    so such tests can be skipped rather than generated."""

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True when the RESIZE mode/dtype combination is invalid.

        kwargs["args"][0] is the resize mode, kwargs["args"][5] the output
        dtype; kwargs["input_dtype"] is the input dtype.
        """
        in_dtype = kwargs["input_dtype"]
        op_args = kwargs["args"]
        mode = op_args[0]
        out_dtype = op_args[5]

        if mode == ResizeMode.BILINEAR:
            # Only these input/output dtype pairings are valid for BILINEAR.
            valid_pairs = (
                (DType.INT8, DType.INT32),
                (DType.INT16, DType.INT48),
                (DType.FLOAT, DType.FLOAT),
            )
            return (in_dtype, out_dtype) not in valid_pairs

        if mode == ResizeMode.NEAREST:
            # NEAREST requires matching dtypes drawn from the supported set.
            return in_dtype != out_dtype or in_dtype not in (
                DType.INT8,
                DType.INT16,
                DType.FLOAT,
            )

        # Any other mode value is an invalid resize mode.
        return True

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when an op's spatial output size would be invalid.

        Handles avg/max_pool2d, transpose_conv2d and the conv2d/conv3d
        family, deriving the output height/width from the input shape,
        strides, padding and kernel/dilation arguments in kwargs["args"].
        """
        op_name = kwargs["opName"]
        shapes = kwargs["shapeList"]
        in_shape = shapes[0]
        op_args = kwargs["args"]
        strides = op_args[0]
        padding = op_args[1]

        if op_name.endswith("pool2d"):
            # avg_pool2d / max_pool2d: invalid when either output dim < 1.
            kernel = op_args[2]
            out_h = (
                in_shape[1] + padding[0] + padding[1] + strides[0] - kernel[0]
            ) // strides[0]
            out_w = (
                in_shape[2] + padding[2] + padding[3] + strides[1] - kernel[1]
            ) // strides[1]
            return out_h < 1 or out_w < 1

        if op_name.startswith("transpose_conv2d"):
            out_shape = op_args[2]
            kernel = shapes[1][1:-1]

            def out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Transpose_conv2d output size for one spatial dimension."""
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # The requested output shape is acceptable if it matches any of
            # the standard padding schemes.
            for pad_h, pad_w in (
                (kernel[0] - 1, kernel[1] - 1),  # FULL padding
                (kernel[0] // 2, kernel[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                expect_h = out_size(
                    in_shape[1], strides[0], kernel[0], padding[0], pad_h
                )
                expect_w = out_size(
                    in_shape[2], strides[1], kernel[1], padding[1], pad_w
                )
                if out_shape[1] == expect_h and out_shape[2] == expect_w:
                    return False
            # Output shape matches no padding option: invalid.
            return True

        if "conv2d" in op_name or "conv3d" in op_name:
            # conv2d / conv3d / depthwise_conv2d: invalid when any spatial
            # output dim < 1.
            dilations = op_args[2]
            filter_shape = shapes[1]
            kernel = (
                filter_shape[0:2]
                if op_name.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )
            for i, kernel_size in enumerate(kernel):
                dim = (
                    in_shape[i + 1]
                    - kernel_size
                    - (kernel_size - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {op_name}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True when the output shape argument (kwargs["args"][2])
        has a non-positive height or width."""
        out_shape = kwargs["args"][2]
        return out_shape[1] <= 0 or out_shape[2] <= 0