blob: f9a00f90818a426a864000d02786068a4552a773 [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005from generator.tosa_utils import product
6from generator.tosa_utils import usableDTypes
7from generator.tosa_utils import valueToName
8from tosa.DType import DType
9from tosa.Op import Op
10from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011
Matthew Haddone86fd342021-09-07 16:12:21 +010012
class ErrorIf(object):
    """Names of every ERROR_IF condition the test generator can provoke.

    Each attribute is the condition's canonical string name; the values are
    used both to select argument-corruption logic in TosaErrorIfArgGen and to
    match validator results in TosaErrorValidator.evValidateErrorIfs.
    """

    # RESIZE parameter errors
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    # Generic type / list / rank errors
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    # Quantization zero-point errors
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    # Axis / reduction errors
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    # Pooling / convolution parameter errors
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    # RESCALE errors
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    # Reshape / slice errors
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    # Gather / scatter errors
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    # Clamp / concat errors
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    # Control-flow (COND_IF / WHILE_LOOP) graph consistency errors
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    # UINT16 zero-point errors
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010076
77
class TosaErrorIfArgGen:
    """Argument mutators that corrupt otherwise-valid operator arguments so
    that a specific ERROR_IF condition is triggered in a generated test."""

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Corrupt RESIZE parameters to provoke the requested ERROR_IF.

        Mutates the scale/offset/border lists in place and returns them
        together with a (possibly replaced) output data type.
        """
        # scale layout is [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # Numerators live at even indices; the spec maximum is 1 << 11
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # Denominators live at odd indices and must be < 16 * numerator
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Pick any output dtype other than the one the spec requires for
            # this mode / input dtype combination.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one element corrupted to trigger
        error_name, or (None, None, None) if error_name is not handled here."""
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when output_dtype is illegal for a RESCALE of input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Randomly grow or shrink the input/output tensor name lists.

        Used for the WrongInputList / WrongOutputList ERROR_IF checks.
        """
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        # Shrink every dimension (floored at 1) until the element count fits
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # BUGFIX: decorator was missing, unlike every sibling helper
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Corrupt SLICE start/size arguments to trigger error_name."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                # BUGFIX: build new lists instead of appending to the caller's
                # lists, which leaked the extra element back to the caller.
                newStart = start + [1]
                newSize = size + [1]
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output dtypes for input_dtype."""
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # BUGFIX: was `assert True`, which can never fire and would fall
            # through to a NameError on the return below.
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
295
296
297class TosaErrorValidator:
298 @staticmethod
299 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
300 """Check ERROR_IF statements are caught and set the expected result.
301
302 Args:
303 serializer: the serializer to set the expected result in
304 validator_fcns: a sequence of validator functions to verify the result
305 error_name: the name of the ERROR_IF condition to check for
306 kwargs: keyword arguments for the validator functions
307 Returns:
308 True if the result matches the expected result; otherwise False
309 """
310 overall_result = True
311 for val_fcn in validator_fcns:
312 val_result = val_fcn(True, **kwargs)
313 validator_name = val_result["error_name"]
314 error_result = val_result["error_result"]
315 error_reason = val_result["error_reason"]
316
317 # expect an error IFF the error_name and validator_name match
318 expected_result = error_result == (error_name == validator_name)
319 overall_result &= expected_result
320
321 if expected_result and error_result:
322 serializer.setExpectedReturnCode(2, True, desc=error_reason)
323 elif error_result: # and not expected_result
324 print(
325 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
326 f" Expected: {error_name}, Got: {validator_name}"
327 )
328 elif not expected_result: # and not error_result
329 print(
330 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
331 f" Expected: {error_name}"
332 )
333
334 if not expected_result:
335 for k, v in sorted(kwargs.items()):
336 if k != "op":
337 if k.endswith("dtype"):
338 v = valueToName(DType, v)
339 print(f" {k} = {v}")
340
341 return overall_result
342
343 @staticmethod
344 def evWrongInputType(check=False, **kwargs):
345 error_result = False
346
347 # Find the unsupported input data types
348 op = kwargs["op"]
349 input_dtypes = op["types"]
350 allowed_input_dtypes = {
351 t[0] if isinstance(t, list) else t for t in input_dtypes
352 }
353 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
354
355 if op["op"] == Op.CLAMP:
356 wrong_input_dtypes.remove(DType.INT48)
357
358 if check:
359 input_dtype = kwargs["input_dtype"]
360 if input_dtype not in allowed_input_dtypes:
361 error_result = True
362
363 info_dict = {
364 "error_name": ErrorIf.WrongInputType,
365 "error_result": error_result,
366 "error_reason": "Input data type not supported for this operator",
367 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
368 }
369 return info_dict
370
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Flag when output_dtype is not the one required for this operator's
        input dtype (and, for RESIZE, its mode).

        Expected kwargs when check is True: op, input_dtype, output_dtype,
        plus mode for RESIZE.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                # NEAREST keeps the input dtype; BILINEAR widens the
                # accumulator (INT8->INT32, INT16->INT48); FLOAT stays FLOAT.
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE's dtype table is shared with the arg generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL accumulates into INT32; FLOAT stays FLOAT
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison ops always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input dtype has a fixed set of legal cast targets
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default: elementwise-style ops must preserve the input dtype
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
501
502 @staticmethod
503 def evWrongRank(check=False, **kwargs):
504 all_ranks = (1, 2, 3, 4, 5)
505
506 # Make a list of incorrect ranks
507 assert "op" in kwargs
508 op = kwargs["op"]
509 rmin, rmax = op["rank"]
510 rank_range = range(rmin, rmax + 1)
511 incorrect_ranks = list(set(all_ranks) - set(rank_range))
512 # Remove small incorrect ranks to avoid index errors
513 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
514 # Set minimum incorrect rank to 3 to avoid index error
515 if op["op"] in [Op.RESIZE]:
516 incorrect_ranks = [3, 5]
517 elif op["op"] in [Op.TRANSPOSE]:
518 incorrect_ranks = [7, 8]
519 elif op["op"] in [Op.CONV3D]:
520 incorrect_ranks = [6, 7]
521
522 error_name = ErrorIf.WrongRank
523 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
524 error_result = False
525 error_reason = "Rank not supported for this operator"
526
527 if check:
528 input_shape = kwargs["input_shape"]
529
530 if (
531 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
532 and len(input_shape) != 4
533 ):
534 error_result = True
535 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
536 error_result = True
537 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
538 error_result = True
539 else:
540 if len(input_shape) not in rank_range:
541 error_result = True
542
543 info_dict = {
544 "error_name": error_name,
545 "error_result": error_result,
546 "error_reason": error_reason,
547 "param_reqs": param_reqs,
548 }
549 return info_dict
550
551 @staticmethod
552 def evWrongInputList(check=False, **kwargs):
553 error_name = ErrorIf.WrongInputList
554 param_reqs = {"rank": None, "dtype": None, "shape": None}
555 error_result = False
556 error_reason = "Op input list does not match expected input"
557
558 if check:
559 op = kwargs["op"]
560 input_list = kwargs["input_list"]
561 num_operands = kwargs["num_operands"]
562 if op["op"] in [Op.SCATTER, Op.GATHER]:
563 # SCATTER/GATHER add an indices input tensor in their build functions
564 num_operands += 1
565 if len(input_list) != num_operands:
566 error_result = True
567
568 info_dict = {
569 "error_name": error_name,
570 "error_result": error_result,
571 "error_reason": error_reason,
572 "param_reqs": param_reqs,
573 }
574 return info_dict
575
576 @staticmethod
577 def evWrongOutputList(check=False, **kwargs):
578 error_name = ErrorIf.WrongOutputList
579 param_reqs = {"rank": None, "dtype": None, "shape": None}
580 error_result = False
581 error_reason = "Op output list does not match expected output"
582
583 if check:
584 output_list = kwargs["output_list"]
585 # Note this will be incorrect if an operator returns more than one output
586 if len(output_list) != 1:
587 error_result = True
588
589 info_dict = {
590 "error_name": error_name,
591 "error_result": error_result,
592 "error_reason": error_reason,
593 "param_reqs": param_reqs,
594 }
595 return info_dict
596
597 @staticmethod
598 def evMaxDimExceeded(check=False, **kwargs):
599 error_name = ErrorIf.MaxDimExceeded
600 param_reqs = {
601 "rank": [4, 4],
602 "dtype": [DType.INT8],
603 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
604 }
605 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100606 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100607
608 if check:
609 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100610 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100611 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100612 (input_shape[1] >= MAX_RESIZE_DIMENSION)
613 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
614 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
615 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100616 ):
617 error_result = True
618
619 info_dict = {
620 "error_name": error_name,
621 "error_result": error_result,
622 "error_reason": error_reason,
623 "param_reqs": param_reqs,
624 }
625 return info_dict
626
627 @staticmethod
628 def evBatchMismatch(check=False, **kwargs):
629 error_name = ErrorIf.BatchMismatch
630 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
631 error_result = False
632 error_reason = "Input batch size not equal to output batch size"
633
634 assert "op" in kwargs
635 op = kwargs["op"]
636 rmin, rmax = op["rank"]
637 rank_range = range(rmin, rmax + 1)
638
639 if check:
640 input_shape = kwargs["input_shape"]
641 output_shape = kwargs[
642 "result_tensor"
643 ].shape # Note this is just (N, OH, OW, C)
644
645 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
646 error_result = True
647
648 info_dict = {
649 "error_name": error_name,
650 "error_result": error_result,
651 "error_reason": error_reason,
652 "param_reqs": param_reqs,
653 }
654 return info_dict
655
656 @staticmethod
657 def evChannelMismatch(check=False, **kwargs):
658 error_name = ErrorIf.ChannelMismatch
659 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
660 error_result = False
661 error_reason = "Input channel size not equal to output channel size"
662
663 assert "op" in kwargs
664 op = kwargs["op"]
665 rmin, rmax = op["rank"]
666 rank_range = range(rmin, rmax + 1)
667
668 if check:
669 input_shape = kwargs["input_shape"]
670 output_shape = kwargs[
671 "result_tensor"
672 ].shape # Note this is just (N, OH, OW, C)
673 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
674 error_result = True
675
676 info_dict = {
677 "error_name": error_name,
678 "error_result": error_result,
679 "error_reason": error_reason,
680 "param_reqs": param_reqs,
681 }
682 return info_dict
683
684 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100685 def evScaleSmallerEqualZero(check=False, **kwargs):
686 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100687 param_reqs = {"rank": None, "dtype": None, "shape": None}
688 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100689 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100690
691 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100692 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100693
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100694 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100695 error_result = True
696
697 info_dict = {
698 "error_name": error_name,
699 "error_result": error_result,
700 "error_reason": error_reason,
701 "param_reqs": param_reqs,
702 }
703 return info_dict
704
705 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100706 def evScaleNLargerMax(check=False, **kwargs):
707 error_name = ErrorIf.ScaleNLargerMax
708 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100709 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100710 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100711
712 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100713 scale = kwargs["scale"]
714
715 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
716 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100717
718 info_dict = {
719 "error_name": error_name,
720 "error_result": error_result,
721 "error_reason": error_reason,
722 "param_reqs": param_reqs,
723 }
724 return info_dict
725
726 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100727 def evScaleDLargerMax(check=False, **kwargs):
728 error_name = ErrorIf.ScaleDLargerMax
729 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100730 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100731 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100732
733 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100734 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100735
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100736 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
737 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100738 ):
739 error_result = True
740
741 info_dict = {
742 "error_name": error_name,
743 "error_result": error_result,
744 "error_reason": error_reason,
745 "param_reqs": param_reqs,
746 }
747 return info_dict
748
749 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100750 def evOffsetSmallerMin(check=False, **kwargs):
751 error_name = ErrorIf.OffsetSmallerMin
752 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100753 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100754 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100755
756 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100757 scale = kwargs["scale"]
758 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100759
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100760 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100761 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100762 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100763 error_result = True
764
765 info_dict = {
766 "error_name": error_name,
767 "error_result": error_result,
768 "error_reason": error_reason,
769 "param_reqs": param_reqs,
770 }
771 return info_dict
772
773 @staticmethod
774 def evOffsetLargerEqualMax(check=False, **kwargs):
775 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100776 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100777 error_result = False
778 error_reason = "Offset value larger than or equal to maximum value"
779
780 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100781 scale = kwargs["scale"]
782 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100783
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100784 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
785 error_result = True
786 elif (
787 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
788 ):
789 error_result = True
790
791 info_dict = {
792 "error_name": error_name,
793 "error_result": error_result,
794 "error_reason": error_reason,
795 "param_reqs": param_reqs,
796 }
797 return info_dict
798
799 @staticmethod
800 def evBorderSmallerMin(check=False, **kwargs):
801 error_name = ErrorIf.BorderSmallerMin
802 param_reqs = {"rank": None, "dtype": None, "shape": None}
803 error_result = False
804 error_reason = "Border value smaller than minimum value"
805
806 if check:
807 scale = kwargs["scale"]
808 border = kwargs["border"]
809
810 if (
811 scale[0] > 0
812 and scale[0] <= (1 << 11)
813 and (border[0] < (-16 * scale[0]))
814 ):
815 error_result = True
816 elif (
817 scale[2] > 0
818 and scale[2] <= (1 << 11)
819 and (border[1] < (-16 * scale[2]))
820 ):
821 error_result = True
822
823 info_dict = {
824 "error_name": error_name,
825 "error_result": error_result,
826 "error_reason": error_reason,
827 "param_reqs": param_reqs,
828 }
829 return info_dict
830
831 @staticmethod
832 def evBorderLargerEqualMax(check=False, **kwargs):
833 error_name = ErrorIf.BorderLargerEqualMax
834 param_reqs = {"rank": None, "dtype": None, "shape": None}
835 error_result = False
836 error_reason = "Border value larger than or equal to maximum value"
837
838 if check:
839 scale = kwargs["scale"]
840 border = kwargs["border"]
841
842 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
843 error_result = True
844 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
845 error_result = True
846
847 info_dict = {
848 "error_name": error_name,
849 "error_result": error_result,
850 "error_reason": error_reason,
851 "param_reqs": param_reqs,
852 }
853 return info_dict
854
855 @staticmethod
856 def checkResizeParams(scale, offset, border):
857 return (
858 min(scale) > 0
859 and max(scale[0], scale[2]) <= (1 << 11)
860 and scale[1] < 16 * scale[0]
861 and scale[3] < 16 * scale[2]
862 and offset[0] >= -scale[0]
863 and offset[1] >= -scale[2]
864 and offset[0] < 16 * scale[0]
865 and offset[1] < 16 * scale[2]
866 and border[0] >= -16 * scale[0]
867 and border[1] >= -16 * scale[2]
868 and border[0] < scale[0]
869 and border[1] < scale[2]
870 )
871
872 @staticmethod
873 def evResizeOutputShapeMismatch(check=False, **kwargs):
874 error_name = ErrorIf.ResizeOutputShapeMismatch
875 param_reqs = {"rank": None, "dtype": None, "shape": None}
876 error_result = False
877 error_reason = (
878 "Mismatch between output shape provided and expected output shape"
879 )
880
881 if check:
882 input_shape = kwargs["input_shape"]
883 output_shape = kwargs["output_shape"]
884 scale = kwargs["scale"]
885 offset = kwargs["offset"]
886 border = kwargs["border"]
887
888 # Ensure parameters are valid
889 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
890
891 if (
892 params_valid
893 and max(output_shape) < MAX_RESIZE_DIMENSION
894 and max(input_shape) < MAX_RESIZE_DIMENSION
895 ):
896 output_y = (
897 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
898 ) // scale[1] + 1
899 output_x = (
900 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
901 ) // scale[3] + 1
902
903 if [output_y, output_x] != output_shape[1:-1]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100904 error_result = True
905
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100906 info_dict = {
907 "error_name": error_name,
908 "error_result": error_result,
909 "error_reason": error_reason,
910 "param_reqs": param_reqs,
911 }
912 return info_dict
913
914 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100915 def evResizeOutputShapeNonInteger(check=False, **kwargs):
916 error_name = ErrorIf.ResizeOutputShapeNonInteger
917 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100918 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100919 error_reason = "Parameters do not yield exact integer output dimensions"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920
921 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100922 input_shape = kwargs["input_shape"]
923 scale = kwargs["scale"]
924 offset = kwargs["offset"]
925 border = kwargs["border"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100926
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100927 # Ensure parameters are valid
928 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100929
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100930 if params_valid:
931 remainder_y = (
932 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
933 ) % scale[1]
934 remainder_x = (
935 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
936 ) % scale[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100937
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100938 if max(remainder_y, remainder_x) > 0:
939 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100940
941 info_dict = {
942 "error_name": error_name,
943 "error_result": error_result,
944 "error_reason": error_reason,
945 "param_reqs": param_reqs,
946 }
947 return info_dict
948
949 @staticmethod
950 def evRankMismatch(check=False, **kwargs):
951 error_name = ErrorIf.RankMismatch
952 param_reqs = {"rank": None, "dtype": None, "shape": None}
953 error_result = False
954 error_reason = "Input Rank does not match output rank"
955
956 if check:
957 input1_shape = kwargs["input1"].shape
958 input2_shape = kwargs["input2"].shape
959 # In case of SELECT op
960 input3_shape = (
961 kwargs["input3"].shape if "input3" in kwargs else input2_shape
962 )
963 output_shape = kwargs["result_tensor"].shape
964 if (
965 (len(input1_shape) != len(output_shape))
966 or (len(input2_shape) != len(output_shape))
967 or (len(input3_shape) != len(output_shape))
968 ):
969 error_result = True
970
971 info_dict = {
972 "error_name": error_name,
973 "error_result": error_result,
974 "error_reason": error_reason,
975 "param_reqs": param_reqs,
976 }
977 return info_dict
978
979 @staticmethod
980 def evDimensionMismatch(check=False, **kwargs):
981 error_name = ErrorIf.DimensionMismatch
982 param_reqs = {"rank": None, "dtype": None, "shape": None}
983 error_result = False
984 error_reason = "Input Dimensions do not match output"
985
986 if check:
987 input1_shape = kwargs["input1"].shape
988 input2_shape = kwargs["input2"].shape
989 # In case of SELECT op
990 input3_shape = (
991 kwargs["input3"].shape if "input3" in kwargs else input2_shape
992 )
993 output_shape = kwargs["result_tensor"].shape
994 for i in range(
995 min(len(input1_shape), len(input2_shape), len(input3_shape))
996 ):
997 if (
998 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
999 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
1000 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
1001 ):
1002 error_result = True
1003
1004 info_dict = {
1005 "error_name": error_name,
1006 "error_result": error_result,
1007 "error_reason": error_reason,
1008 "param_reqs": param_reqs,
1009 }
1010 return info_dict
1011
1012 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001013 def _getZeroPoint(qinfo, index):
1014 """Return zero point value from quantization info.
1015
1016 Generally input_zp is index 0, output_zp is index 1
1017 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001018 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001019
1020 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 def evInputZeroPointNotZero(check=False, **kwargs):
1022 op = kwargs["op"]
1023 error_result = False
1024
1025 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001026 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027
1028 # This does not apply to quantizable types
1029 inputDtypes = [
1030 dtype
1031 for dtype in op["types"]
1032 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1033 or (not isinstance(dtype, list) and dtype not in qTypes)
1034 ]
1035
1036 if check:
1037 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001038 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001039 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001040 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001041 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001042 (kwargs["input_dtype"], input_zero_point),
1043 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044 ):
1045 if dtype not in qTypes and zp != 0:
1046 error_result = True
1047 break
1048 else:
1049 error_result = input_dtype not in qTypes and input_zero_point != 0
1050
1051 info_dict = {
1052 "error_name": ErrorIf.InputZeroPointNotZero,
1053 "error_result": error_result,
1054 "error_reason": "Input DType not INT8 and zero point not 0",
1055 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1056 }
1057 return info_dict
1058
1059 @staticmethod
1060 def evWeightZeroPointNotZero(check=False, **kwargs):
1061 op = kwargs["op"]
1062
1063 # exclude inputs with INT8 weights
1064 inputDtypes = [
1065 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1066 ]
1067
1068 error_name = ErrorIf.WeightZeroPointNotZero
1069 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1070 error_result = False
1071 error_reason = "Weight DType not INT8 and zero point not 0"
1072
1073 if check:
1074 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001075 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001076 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1077 error_result = True
1078
1079 info_dict = {
1080 "error_name": error_name,
1081 "error_result": error_result,
1082 "error_reason": error_reason,
1083 "param_reqs": param_reqs,
1084 }
1085 return info_dict
1086
1087 @staticmethod
1088 def evOutputZeroPointNotZero(check=False, **kwargs):
1089 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001090 inputDtypes = [
1091 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1092 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001093
1094 error_name = ErrorIf.OutputZeroPointNotZero
1095 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1096 error_result = False
1097 error_reason = "Output DType not INT8 and zero point not 0"
1098
1099 if check:
1100 input_dtype = kwargs["input_dtype"]
1101 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001102 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001103 if op["op"] == Op.AVG_POOL2D:
1104 if input_dtype != DType.INT8 and output_zero_point != 0:
1105 error_result = True
1106 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001107 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1108 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001109 ):
1110 error_result = True
1111
1112 info_dict = {
1113 "error_name": error_name,
1114 "error_result": error_result,
1115 "error_reason": error_reason,
1116 "param_reqs": param_reqs,
1117 }
1118 return info_dict
1119
1120 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001121 def evU16InputZeroPointNotValid(check=False, **kwargs):
1122 error_name = ErrorIf.U16InputZeroPointNotValid
1123 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1124 error_result = False
1125 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1126
1127 if check:
1128 input_dtype = kwargs["input_dtype"]
1129 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1130 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1131 0,
1132 32768,
1133 ]
1134
1135 info_dict = {
1136 "error_name": error_name,
1137 "error_result": error_result,
1138 "error_reason": error_reason,
1139 "param_reqs": param_reqs,
1140 }
1141 return info_dict
1142
1143 @staticmethod
1144 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1145 error_name = ErrorIf.U16OutputZeroPointNotValid
1146 param_reqs = {"rank": None, "dtype": None, "shape": None}
1147 error_result = False
1148 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1149
1150 if check:
1151 output_dtype = kwargs["output_dtype"]
1152 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1153
1154 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1155 0,
1156 32768,
1157 ]
1158
1159 info_dict = {
1160 "error_name": error_name,
1161 "error_result": error_result,
1162 "error_reason": error_reason,
1163 "param_reqs": param_reqs,
1164 }
1165 return info_dict
1166
1167 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001168 def evAxisSmallerZero(check=False, **kwargs):
1169 error_name = ErrorIf.AxisSmallerZero
1170 param_reqs = {"rank": None, "dtype": None, "shape": None}
1171 error_result = False
1172 error_reason = "Axis smaller than zero"
1173
1174 if check:
1175 axis = kwargs["axis"]
1176 if axis < 0:
1177 error_result = True
1178
1179 info_dict = {
1180 "error_name": error_name,
1181 "error_result": error_result,
1182 "error_reason": error_reason,
1183 "param_reqs": param_reqs,
1184 }
1185 return info_dict
1186
1187 @staticmethod
1188 def evAxisLargerRank(check=False, **kwargs):
1189 error_name = ErrorIf.AxisLargerRank
1190 param_reqs = {"rank": None, "dtype": None, "shape": None}
1191 error_result = False
1192 error_reason = "Axis larger than rank"
1193
1194 if check:
1195 axis = kwargs["axis"]
1196 shape = kwargs["input_shape"]
1197 if axis > len(shape):
1198 error_result = True
1199
1200 info_dict = {
1201 "error_name": error_name,
1202 "error_result": error_result,
1203 "error_reason": error_reason,
1204 "param_reqs": param_reqs,
1205 }
1206 return info_dict
1207
1208 @staticmethod
1209 def evShapeOfAxisNotOne(check=False, **kwargs):
1210 error_name = ErrorIf.ShapeOfAxisNotOne
1211 param_reqs = {"rank": None, "dtype": None, "shape": None}
1212 error_result = False
1213 error_reason = "shape[axis] is not equal to 1"
1214
1215 if check:
1216 axis = kwargs["axis"]
1217 shape = kwargs["output_shape"]
1218 if (0 <= axis < len(shape)) and shape[axis] != 1:
1219 error_result = True
1220
1221 info_dict = {
1222 "error_name": error_name,
1223 "error_result": error_result,
1224 "error_reason": error_reason,
1225 "param_reqs": param_reqs,
1226 }
1227 return info_dict
1228
1229 @staticmethod
1230 def evPadSmallerZero(check=False, **kwargs):
1231 error_name = ErrorIf.PadSmallerZero
1232 param_reqs = {"rank": None, "dtype": None, "shape": None}
1233 error_result = False
1234 error_reason = "At least one pad is smaller than zero"
1235
1236 if check:
1237 op = kwargs["op"]
1238 pad = kwargs["pad"]
1239 if op["op"] == Op.PAD:
1240 for padding in pad:
1241 if min(padding) < 0:
1242 error_result = True
1243 else:
1244 if min(pad) < 0:
1245 error_result = True
1246
1247 info_dict = {
1248 "error_name": error_name,
1249 "error_result": error_result,
1250 "error_reason": error_reason,
1251 "param_reqs": param_reqs,
1252 }
1253 return info_dict
1254
1255 @staticmethod
1256 def evPadLargerEqualKernel(check=False, **kwargs):
1257 error_name = ErrorIf.PadLargerEqualKernel
1258 param_reqs = {"rank": None, "dtype": None, "shape": None}
1259 error_result = False
1260 error_reason = "At least one pad is larger than kernel dimension"
1261
1262 if check:
1263 pad = kwargs["pad"]
Eric Kunzec1a97832022-07-01 16:56:09 -07001264 op = kwargs["op"]
1265 if op["op"] == Op.TRANSPOSE_CONV2D:
1266 # transpose_conv2d
1267 kernel = kwargs["weight_shape"][1:-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268 if (
Eric Kunzec1a97832022-07-01 16:56:09 -07001269 pad[0] <= -kernel[0]
1270 or pad[1] <= -kernel[0]
1271 or pad[2] <= -kernel[1]
1272 or pad[3] <= -kernel[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001273 ):
1274 error_result = True
Eric Kunzec1a97832022-07-01 16:56:09 -07001275 else:
1276 # pooling op
1277 kernel = kwargs["kernel"]
1278 if min(pad) > 0 and min(kernel) > 1:
1279 if (
1280 pad[0] >= kernel[0]
1281 or pad[1] >= kernel[0]
1282 or pad[2] >= kernel[1]
1283 or pad[3] >= kernel[1]
1284 ):
1285 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001286
1287 info_dict = {
1288 "error_name": error_name,
1289 "error_result": error_result,
1290 "error_reason": error_reason,
1291 "param_reqs": param_reqs,
1292 }
1293 return info_dict
1294
1295 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001296 def evPadOutputShapeMismatch(check=False, **kwargs):
1297 error_name = ErrorIf.PadOutputShapeMismatch
1298 param_reqs = {"rank": None, "dtype": None, "shape": None}
1299 error_result = False
1300 error_reason = "Pad output shape mismatch for requested padding"
1301
1302 if check:
1303 pad = kwargs["pad"]
1304 input_shape = kwargs["input_shape"]
1305 output_shape = kwargs["output_shape"]
1306 for dim, padding in enumerate(pad):
1307 expected_size = input_shape[dim] + padding[0] + padding[1]
1308 if expected_size != output_shape[dim]:
1309 error_result = True
1310
1311 info_dict = {
1312 "error_name": error_name,
1313 "error_result": error_result,
1314 "error_reason": error_reason,
1315 "param_reqs": param_reqs,
1316 }
1317 return info_dict
1318
1319 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001320 def checkPoolingParams(kernel, stride, pad):
1321 return (
1322 min(kernel) >= 1
1323 and min(stride) >= 1
1324 and min(pad) >= 0
1325 and not (
1326 pad[0] >= kernel[0]
1327 or pad[1] >= kernel[0]
1328 or pad[2] >= kernel[1]
1329 or pad[3] >= kernel[1]
1330 )
1331 )
1332
1333 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1335 error_name = ErrorIf.PoolingOutputShapeMismatch
1336 param_reqs = {"rank": None, "dtype": None, "shape": None}
1337 error_result = False
1338 error_reason = (
1339 "Mismatch between output shape provided and expected output shape"
1340 )
1341
1342 if check:
1343 pad = kwargs["pad"]
1344 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1345
1346 kernel = kwargs["kernel"]
1347 kernel_y, kernel_x = kernel[0], kernel[1]
1348
1349 input_shape = kwargs["input_shape"]
1350 IH, IW = input_shape[1], input_shape[2]
1351
1352 output_shape = kwargs["output_shape"]
1353 OH, OW = output_shape[1], output_shape[2]
1354
1355 stride = kwargs["stride"]
1356 stride_y, stride_x = stride[0], stride[1]
1357
1358 # calculate correct height, width dimensions
1359 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001360 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1361 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001362
1363 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001364 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001365
1366 if params_valid and (OH != y_correct or OW != x_correct):
1367 error_result = True
1368
1369 info_dict = {
1370 "error_name": error_name,
1371 "error_result": error_result,
1372 "error_reason": error_reason,
1373 "param_reqs": param_reqs,
1374 }
1375 return info_dict
1376
1377 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001378 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1379 error_name = ErrorIf.PoolingOutputShapeNonInteger
1380 param_reqs = {"rank": None, "dtype": None, "shape": None}
1381 error_result = False
1382 error_reason = "Parameters do not yield exact integer output dimensions"
1383
1384 if check:
1385 pad = kwargs["pad"]
1386 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1387
1388 kernel = kwargs["kernel"]
1389 kernel_y, kernel_x = kernel[0], kernel[1]
1390
1391 input_shape = kwargs["input_shape"]
1392 IH, IW = input_shape[1], input_shape[2]
1393
1394 stride = kwargs["stride"]
1395 stride_y, stride_x = stride[0], stride[1]
1396
1397 # calculate remainder of height, width dimensions
1398 if stride_x != 0 and stride_y != 0:
1399 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1400 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1401
1402 # ensure parameters are valid
1403 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1404 if params_valid and (y_remainder != 0 or x_remainder != 0):
1405 error_result = True
1406
1407 info_dict = {
1408 "error_name": error_name,
1409 "error_result": error_result,
1410 "error_reason": error_reason,
1411 "param_reqs": param_reqs,
1412 }
1413 return info_dict
1414
1415 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001416 def checkConvParams(op, weight_shape, stride, pad, dilation):
1417 if op == Op.TRANSPOSE_CONV2D:
1418 pad_ok = (
1419 pad[0] > -weight_shape[1]
1420 and pad[1] > -weight_shape[1]
1421 and pad[2] > -weight_shape[2]
1422 and pad[3] > -weight_shape[2]
1423 )
1424 else:
1425 pad_ok = min(pad) >= 0
1426
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001427 return (
1428 # Check kernel sizes
1429 min(weight_shape[1:-1]) >= 1
1430 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001431 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001432 and (dilation is None or min(dilation) >= 1)
1433 )
1434
1435 @staticmethod
1436 def evConvOutputShapeMismatch(check=False, **kwargs):
1437 error_name = ErrorIf.ConvOutputShapeMismatch
1438 param_reqs = {"rank": None, "dtype": None, "shape": None}
1439 error_result = False
1440 error_reason = (
1441 "Mismatch between output shape provided and expected output shape"
1442 )
1443
1444 if check:
1445 op = kwargs["op"]
1446 pad = kwargs["pad"]
1447 weight_shape = kwargs["weight_shape"]
1448 input_shape = kwargs["input_shape"]
1449 output_shape = kwargs["output_shape"]
1450 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1451 stride = kwargs["stride"]
1452
1453 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1454
1455 # calculate correct dimensions
1456 dims_correct = []
1457 if min(stride) > 0:
1458 for index in range(len(stride)):
1459 pad_offset = index * 2
1460 if op["op"] == Op.TRANSPOSE_CONV2D:
1461 dims_correct.append(
1462 (input_shape[index + 1] - 1) * stride[index]
Eric Kunzec1a97832022-07-01 16:56:09 -07001463 + pad[pad_offset]
1464 + pad[pad_offset + 1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001465 + weight_shape[index + kernel_offset]
1466 )
1467 else:
1468 dims_correct.append(
1469 (
1470 input_shape[index + 1]
1471 - 1
1472 + pad[pad_offset]
1473 + pad[pad_offset + 1]
1474 - (weight_shape[index + kernel_offset] - 1)
1475 * dilation[index]
1476 )
1477 // stride[index]
1478 + 1
1479 )
1480
1481 # ensure parameters are valid
1482 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001483 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001484 )
1485
1486 if params_valid and output_shape[1:-1] != dims_correct:
1487 error_result = True
1488
1489 info_dict = {
1490 "error_name": error_name,
1491 "error_result": error_result,
1492 "error_reason": error_reason,
1493 "param_reqs": param_reqs,
1494 }
1495 return info_dict
1496
1497 @staticmethod
1498 def evConvOutputShapeNonInteger(check=False, **kwargs):
1499 error_name = ErrorIf.ConvOutputShapeNonInteger
1500 param_reqs = {"rank": None, "dtype": None, "shape": None}
1501 error_result = False
1502 error_reason = "Parameters do not yield exact integer output dimensions"
1503
1504 if check:
1505 op = kwargs["op"]
1506 pad = kwargs["pad"]
1507 weight_shape = kwargs["weight_shape"]
1508 input_shape = kwargs["input_shape"]
1509 dilation = kwargs["dilation"]
1510 stride = kwargs["stride"]
1511
1512 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1513
1514 # calculate correct height, width dimensions
1515 remainders = []
1516 if min(stride) > 0:
1517 for index in range(len(stride)):
1518 pad_offset = index * 2
1519 remainders.append(
1520 (
1521 input_shape[index + 1]
1522 - 1
1523 + pad[pad_offset]
1524 + pad[pad_offset + 1]
1525 - (weight_shape[index + kernel_offset] - 1)
1526 * dilation[index]
1527 )
1528 % stride[index]
1529 )
1530
1531 # ensure parameters are valid
1532 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001533 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001534 )
1535 if params_valid and max(remainders) > 0:
1536 error_result = True
1537
1538 info_dict = {
1539 "error_name": error_name,
1540 "error_result": error_result,
1541 "error_reason": error_reason,
1542 "param_reqs": param_reqs,
1543 }
1544 return info_dict
1545
1546 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001547 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1548 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1549 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1550 error_result = False
1551 error_reason = (
1552 "Mismatch between output shape provided and expected output shape"
1553 )
1554
1555 if check:
1556 output_shape = kwargs["output_shape"]
1557 input_shape = kwargs["input_shape"]
1558 axis = kwargs["axis"]
1559
1560 dimension_match = True
1561 axis_shift = 0
1562
1563 # Check that rank is correct before trying to check dimensions
1564 if (len(input_shape) - 1) == len(output_shape):
1565 for i in range(len(input_shape)):
1566 if i == axis:
1567 axis_shift = 1
1568 continue
1569 if input_shape[i] != output_shape[i - axis_shift]:
1570 dimension_match = False
1571
1572 if not dimension_match:
1573 error_result = True
1574
1575 info_dict = {
1576 "error_name": error_name,
1577 "error_result": error_result,
1578 "error_reason": error_reason,
1579 "param_reqs": param_reqs,
1580 }
1581 return info_dict
1582
1583 @staticmethod
1584 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1585 error_name = ErrorIf.ArgmaxOutputRankMismatch
1586 param_reqs = {"rank": None, "dtype": None, "shape": None}
1587 error_result = False
1588 error_reason = (
1589 "Mismatch between output shape provided and expected output shape"
1590 )
1591
1592 if check:
1593 output_shape = kwargs["output_shape"]
1594 input_shape = kwargs["input_shape"]
1595 axis = kwargs["axis"]
1596 valid_params = axis >= 0 and axis < len(input_shape)
1597
1598 if valid_params and (len(input_shape) - 1) != len(output_shape):
1599 error_result = True
1600
1601 info_dict = {
1602 "error_name": error_name,
1603 "error_result": error_result,
1604 "error_reason": error_reason,
1605 "param_reqs": param_reqs,
1606 }
1607 return info_dict
1608
1609 @staticmethod
1610 def evKernelSmallerOne(check=False, **kwargs):
1611 error_name = ErrorIf.KernelSmallerOne
1612 param_reqs = {"rank": None, "dtype": None, "shape": None}
1613 error_result = False
1614 error_reason = "At least one kernel dimension is smaller than zero"
1615
1616 if check:
1617 kernel = kwargs["kernel"]
1618 if min(kernel) < 1:
1619 error_result = True
1620
1621 info_dict = {
1622 "error_name": error_name,
1623 "error_result": error_result,
1624 "error_reason": error_reason,
1625 "param_reqs": param_reqs,
1626 }
1627 return info_dict
1628
1629 @staticmethod
1630 def evStrideSmallerOne(check=False, **kwargs):
1631 error_name = ErrorIf.StrideSmallerOne
1632 param_reqs = {"rank": None, "dtype": None, "shape": None}
1633 error_result = False
1634 error_reason = "At least one stride dimension is smaller than zero"
1635
1636 if check:
1637 stride = kwargs["stride"]
1638 if min(stride) < 1:
1639 error_result = True
1640
1641 info_dict = {
1642 "error_name": error_name,
1643 "error_result": error_result,
1644 "error_reason": error_reason,
1645 "param_reqs": param_reqs,
1646 }
1647 return info_dict
1648
1649 @staticmethod
1650 def evDilationSmallerOne(check=False, **kwargs):
1651 error_result = check and min(kwargs["dilation"]) < 1
1652 return {
1653 "error_name": ErrorIf.DilationSmallerOne,
1654 "error_reason": "At least one dilation is smaller than one",
1655 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1656 "error_result": error_result,
1657 }
1658
1659 @staticmethod
1660 def evScaleTrue(check=False, **kwargs):
1661 error_name = ErrorIf.ScaleTrue
1662 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1663 error_result = False
1664 error_reason = "Scale set to true but input type is INT48"
1665
1666 if check:
1667 input_dtype = kwargs["input_dtype"]
1668 scale32 = kwargs["scale32"]
1669 if scale32 and input_dtype == DType.INT48:
1670 error_result = True
1671
1672 info_dict = {
1673 "error_name": error_name,
1674 "error_result": error_result,
1675 "error_reason": error_reason,
1676 "param_reqs": param_reqs,
1677 }
1678 return info_dict
1679
1680 @staticmethod
1681 def evScaleNotTrue(check=False, **kwargs):
1682 error_name = ErrorIf.ScaleNotTrue
1683 param_reqs = {"rank": None, "dtype": None, "shape": None}
1684 error_result = False
1685 error_reason = "Scale set to false but double round set to true"
1686
1687 if check:
1688 scale32 = kwargs["scale32"]
1689 double_round = kwargs["double_round"]
1690 if not scale32 and double_round:
1691 error_result = True
1692
1693 info_dict = {
1694 "error_name": error_name,
1695 "error_result": error_result,
1696 "error_reason": error_reason,
1697 "param_reqs": param_reqs,
1698 }
1699 return info_dict
1700
1701 @staticmethod
1702 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1703 error_name = ErrorIf.TensorSizeInputOutputMismatch
1704 param_reqs = {"rank": None, "dtype": None, "shape": None}
1705 error_result = False
1706 error_reason = "Input tensor size does not match output tensor size"
1707
1708 if check:
1709 input_shape = kwargs["input_shape"]
1710 output_shape = kwargs["output_shape"]
1711 input_size = np.prod(input_shape)
1712 output_size = np.prod(output_shape)
1713 if input_size != output_size:
1714 error_result = True
1715
1716 info_dict = {
1717 "error_name": error_name,
1718 "error_result": error_result,
1719 "error_reason": error_reason,
1720 "param_reqs": param_reqs,
1721 }
1722 return info_dict
1723
1724 @staticmethod
1725 def evStartSmallerZero(check=False, **kwargs):
1726 error_name = ErrorIf.StartSmallerZero
1727 param_reqs = {"rank": None, "dtype": None, "shape": None}
1728 error_result = False
1729 error_reason = "Starting point smaller than zero"
1730
1731 if check:
1732 input_shape = kwargs["input_shape"]
1733 start = kwargs["start"]
1734 rank = len(input_shape)
1735 if len(start) == rank:
1736 for index in range(rank):
1737 if start[index] < 0:
1738 error_result = True
1739
1740 info_dict = {
1741 "error_name": error_name,
1742 "error_result": error_result,
1743 "error_reason": error_reason,
1744 "param_reqs": param_reqs,
1745 }
1746 return info_dict
1747
1748 @staticmethod
1749 def evSizeSmallerEqualZero(check=False, **kwargs):
1750 error_name = ErrorIf.SizeSmallerEqualZero
1751 param_reqs = {"rank": None, "dtype": None, "shape": None}
1752 error_result = False
1753 error_reason = "Size smaller than or equal to zero"
1754
1755 if check:
1756 input_shape = kwargs["input_shape"]
1757 size = kwargs["size"]
1758 rank = len(input_shape)
1759 if len(size) == rank:
1760 for index in range(rank):
1761 if size[index] <= 0:
1762 error_result = True
1763
1764 info_dict = {
1765 "error_name": error_name,
1766 "error_result": error_result,
1767 "error_reason": error_reason,
1768 "param_reqs": param_reqs,
1769 }
1770 return info_dict
1771
1772 @staticmethod
1773 def evStartSizeOutsideBounds(check=False, **kwargs):
1774 error_name = ErrorIf.StartSizeOutsideBounds
1775 param_reqs = {"rank": None, "dtype": None, "shape": None}
1776 error_result = False
1777 error_reason = "starting point plus size larger than input dimension"
1778
1779 if check:
1780 input_shape = kwargs["input_shape"]
1781 start = kwargs["start"]
1782 size = kwargs["size"]
1783 rank = len(input_shape)
1784 if len(start) == rank and len(size) == rank:
1785 for index in range(rank):
1786 if start[index] + size[index] > input_shape[index]:
1787 error_result = True
1788
1789 info_dict = {
1790 "error_name": error_name,
1791 "error_result": error_result,
1792 "error_reason": error_reason,
1793 "param_reqs": param_reqs,
1794 }
1795 return info_dict
1796
1797 @staticmethod
1798 def evSizeOutputShapeMismatch(check=False, **kwargs):
1799 error_name = ErrorIf.SizeOutputShapeMismatch
1800 param_reqs = {"rank": None, "dtype": None, "shape": None}
1801 error_result = False
1802 error_reason = "Size does not match output dimension"
1803
1804 if check:
1805 input_shape = kwargs["input_shape"]
1806 output_shape = kwargs["output_shape"]
1807 size = kwargs["size"]
1808 rank = len(input_shape)
1809 if len(size) == rank:
1810 for index in range(rank):
1811 if size[index] != output_shape[index]:
1812 error_result = True
1813
1814 info_dict = {
1815 "error_name": error_name,
1816 "error_result": error_result,
1817 "error_reason": error_reason,
1818 "param_reqs": param_reqs,
1819 }
1820 return info_dict
1821
1822 @staticmethod
1823 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1824 error_name = ErrorIf.InputSizeStartLengthMismatch
1825 param_reqs = {"rank": None, "dtype": None, "shape": None}
1826 error_result = False
1827 error_reason = "rank of input not equal to length of start or size"
1828
1829 if check:
1830 input_shape = kwargs["input_shape"]
1831 start = kwargs["start"]
1832 size = kwargs["size"]
1833 rank = len(input_shape)
1834 if rank != len(start) or rank != len(size):
1835 error_result = True
1836
1837 info_dict = {
1838 "error_name": error_name,
1839 "error_result": error_result,
1840 "error_reason": error_reason,
1841 "param_reqs": param_reqs,
1842 }
1843 return info_dict
1844
1845 @staticmethod
1846 def evIndexOutsideBounds(check=False, **kwargs):
1847 error_name = ErrorIf.IndexOutsideBounds
1848 param_reqs = {"rank": None, "dtype": None, "shape": None}
1849 error_result = False
1850 error_reason = "Index outside of allowed bounds"
1851
1852 if check:
1853 input_shape = kwargs["input_shape"]
1854 perms = kwargs["perms"]
1855 rank = len(input_shape)
1856
1857 for index in perms:
1858 if index < 0 or index > rank:
1859 error_result = True
1860
1861 info_dict = {
1862 "error_name": error_name,
1863 "error_result": error_result,
1864 "error_reason": error_reason,
1865 "param_reqs": param_reqs,
1866 }
1867 return info_dict
1868
1869 @staticmethod
1870 def evIndexUsedTwice(check=False, **kwargs):
1871 error_name = ErrorIf.IndexUsedTwice
1872 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1873 error_result = False
1874 error_reason = "Index used multiple times"
1875
1876 if check:
1877 perms = kwargs["perms"]
1878
1879 unique_indices = []
1880 for index in perms:
1881 if index in unique_indices:
1882 error_result = True
1883 else:
1884 unique_indices.append(index)
1885
1886 info_dict = {
1887 "error_name": error_name,
1888 "error_result": error_result,
1889 "error_reason": error_reason,
1890 "param_reqs": param_reqs,
1891 }
1892 return info_dict
1893
1894 @staticmethod
1895 def evMaxSmallerMin(check=False, **kwargs):
1896 error_name = ErrorIf.MaxSmallerMin
1897 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1898 error_result = False
1899 error_reason = "Max value smaller than min value"
1900
1901 if check:
1902 max_val = kwargs["max_val"]
1903 min_val = kwargs["min_val"]
1904 if max_val < min_val:
1905 error_result = True
1906
1907 info_dict = {
1908 "error_name": error_name,
1909 "error_result": error_result,
1910 "error_reason": error_reason,
1911 "param_reqs": param_reqs,
1912 }
1913 return info_dict
1914
1915 @staticmethod
1916 def evConcatInputRankMismatch(check=False, **kwargs):
1917 error_name = ErrorIf.ConcatInputRankMismatch
1918 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1919 error_result = False
1920 error_reason = "Input ranks are not identical"
1921
1922 if check:
1923 inputs = kwargs["inputs"]
1924 input_shape = kwargs["input_shape"]
1925 for input in inputs:
1926 if len(input.shape) != len(input_shape):
1927 error_result = True
1928
1929 info_dict = {
1930 "error_name": error_name,
1931 "error_result": error_result,
1932 "error_reason": error_reason,
1933 "param_reqs": param_reqs,
1934 }
1935 return info_dict
1936
1937 @staticmethod
1938 def evConcatInputDimMismatch(check=False, **kwargs):
1939 error_name = ErrorIf.ConcatInputDimMismatch
1940 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1941 error_result = False
1942 error_reason = "Input dimensions differ on too many axes"
1943
1944 if check:
1945 inputs = kwargs["inputs"]
1946 input_shape = kwargs["input_shape"]
1947 axis = kwargs["axis"]
1948
1949 # Ensure rank is valid before checking dims.
1950 valid_rank = True
1951 for input in inputs:
1952 if len(input.shape) != len(input_shape):
1953 valid_rank = False
1954
1955 if valid_rank:
1956 for input in inputs:
1957 for i, dim in enumerate(input.shape):
1958 if dim != input_shape[i] and axis != i:
1959 error_result = True
1960
1961 info_dict = {
1962 "error_name": error_name,
1963 "error_result": error_result,
1964 "error_reason": error_reason,
1965 "param_reqs": param_reqs,
1966 }
1967 return info_dict
1968
1969 @staticmethod
1970 def evConcatShapeSumMismatch(check=False, **kwargs):
1971 error_name = ErrorIf.ConcatShapeSumMismatch
1972 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1973 error_result = False
1974 error_reason = "Sum of dimensions on axis not equal to output dimension"
1975
1976 if check:
1977 inputs = kwargs["inputs"]
1978 input_shape = kwargs["input_shape"]
1979 output_shape = kwargs["output_shape"]
1980 axis = kwargs["axis"]
1981
1982 # Ensure rank is valid before checking dims.
1983 valid_params = True
1984 for input in inputs:
1985 if len(input.shape) != len(input_shape):
1986 valid_params = False
1987 if axis < 0 or axis > len(input_shape):
1988 valid_params = False
1989
1990 if valid_params:
1991 axis_dim_sum = 0
1992 for input in inputs:
1993 axis_dim_sum += input.shape[axis]
1994
1995 if axis_dim_sum != output_shape[axis]:
1996 error_result = True
1997
1998 info_dict = {
1999 "error_name": error_name,
2000 "error_result": error_result,
2001 "error_reason": error_reason,
2002 "param_reqs": param_reqs,
2003 }
2004 return info_dict
2005
2006 @staticmethod
2007 def evInputListThenGraphMismatch(check=False, **kwargs):
2008 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2009 param_reqs = {"rank": None, "dtype": None, "shape": None}
2010 error_result = False
2011 error_reason = "Input list shape does not match then-graph shape"
2012
2013 if check:
2014 a = kwargs["a"]
2015 b = kwargs["b"]
2016 basicBlocks = kwargs["basicBlocks"]
2017 then_block = basicBlocks[1]
2018 then_inputs = then_block.inputs
2019 then_tens = then_block.tensors
2020 if (a.shape != then_tens[then_inputs[0]].shape) or (
2021 b.shape != then_tens[then_inputs[1]].shape
2022 ):
2023 error_result = True
2024
2025 info_dict = {
2026 "error_name": error_name,
2027 "error_result": error_result,
2028 "error_reason": error_reason,
2029 "param_reqs": param_reqs,
2030 }
2031 return info_dict
2032
2033 @staticmethod
2034 def evInputListElseGraphMismatch(check=False, **kwargs):
2035 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2036 param_reqs = {"rank": None, "dtype": None, "shape": None}
2037 error_result = False
2038 error_reason = "Input list shape does not match else-graph shape"
2039
2040 if check:
2041 a = kwargs["a"]
2042 b = kwargs["b"]
2043 basicBlocks = kwargs["basicBlocks"]
2044 else_block = basicBlocks[2]
2045 else_inputs = else_block.inputs
2046 else_tens = else_block.tensors
2047 if (a.shape != else_tens[else_inputs[0]].shape) or (
2048 b.shape != else_tens[else_inputs[1]].shape
2049 ):
2050 error_result = True
2051
2052 info_dict = {
2053 "error_name": error_name,
2054 "error_result": error_result,
2055 "error_reason": error_reason,
2056 "param_reqs": param_reqs,
2057 }
2058 return info_dict
2059
2060 @staticmethod
2061 def evOutputListThenGraphMismatch(check=False, **kwargs):
2062 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2063 param_reqs = {"rank": None, "dtype": None, "shape": None}
2064 error_result = False
2065 error_reason = "Output list shape does not match then-graph shape"
2066
2067 if check:
2068 basicBlocks = kwargs["basicBlocks"]
2069 cond_block = basicBlocks[0]
2070 cond_outputs = cond_block.outputs
2071 cond_tens = cond_block.tensors
2072 then_block = basicBlocks[1]
2073 then_outputs = then_block.outputs
2074 then_tens = then_block.tensors
2075 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2076 error_result = True
2077
2078 info_dict = {
2079 "error_name": error_name,
2080 "error_result": error_result,
2081 "error_reason": error_reason,
2082 "param_reqs": param_reqs,
2083 }
2084 return info_dict
2085
2086 @staticmethod
2087 def evOutputListElseGraphMismatch(check=False, **kwargs):
2088 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2089 param_reqs = {"rank": None, "dtype": None, "shape": None}
2090 error_result = False
2091 error_reason = "Output list shape does not match else-graph shape"
2092
2093 if check:
2094 basicBlocks = kwargs["basicBlocks"]
2095 cond_block = basicBlocks[0]
2096 cond_outputs = cond_block.outputs
2097 cond_tens = cond_block.tensors
2098 else_block = basicBlocks[2]
2099 else_outputs = else_block.outputs
2100 else_tens = else_block.tensors
2101 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2102 error_result = True
2103
2104 info_dict = {
2105 "error_name": error_name,
2106 "error_result": error_result,
2107 "error_reason": error_reason,
2108 "param_reqs": param_reqs,
2109 }
2110 return info_dict
2111
2112 @staticmethod
2113 def evInputListOutputListMismatch(check=False, **kwargs):
2114 error_name = ErrorIf.InputListOutputListMismatch
2115 param_reqs = {"rank": None, "dtype": None, "shape": None}
2116 error_result = False
2117 error_reason = "Input list does not match output list"
2118
2119 if check:
2120 basicBlocks = kwargs["basicBlocks"]
2121 while_block = basicBlocks[0]
2122 while_inputs = while_block.inputs
2123 while_outputs = while_block.outputs
2124 while_tens = while_block.tensors
2125 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2126 error_result = True
2127
2128 info_dict = {
2129 "error_name": error_name,
2130 "error_result": error_result,
2131 "error_reason": error_reason,
2132 "param_reqs": param_reqs,
2133 }
2134 return info_dict
2135
2136 @staticmethod
2137 def evInputListCondGraphMismatch(check=False, **kwargs):
2138 error_name = ErrorIf.InputListCondGraphMismatch
2139 param_reqs = {"rank": None, "dtype": None, "shape": None}
2140 error_result = False
2141 error_reason = "Input list does not match cond graph"
2142
2143 if check:
2144 basicBlocks = kwargs["basicBlocks"]
2145 while_block = basicBlocks[0]
2146 while_inputs = while_block.inputs
2147 while_tens = while_block.tensors
2148 cond_block = basicBlocks[1]
2149 cond_inputs = cond_block.inputs
2150 cond_tens = cond_block.tensors
2151 if (
2152 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2153 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2154 error_result = True
2155
2156 info_dict = {
2157 "error_name": error_name,
2158 "error_result": error_result,
2159 "error_reason": error_reason,
2160 "param_reqs": param_reqs,
2161 }
2162 return info_dict
2163
2164 @staticmethod
2165 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2166 error_name = ErrorIf.InputListBodyGraphInputMismatch
2167 param_reqs = {"rank": None, "dtype": None, "shape": None}
2168 error_result = False
2169 error_reason = "Input list does not match body graph input"
2170
2171 if check:
2172 basicBlocks = kwargs["basicBlocks"]
2173 while_block = basicBlocks[0]
2174 while_inputs = while_block.inputs
2175 while_tens = while_block.tensors
2176 body_block = basicBlocks[2]
2177 body_outputs = body_block.inputs
2178 body_tens = body_block.tensors
2179 if (
2180 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2181 ) or (
2182 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2183 ):
2184 error_result = True
2185
2186 info_dict = {
2187 "error_name": error_name,
2188 "error_result": error_result,
2189 "error_reason": error_reason,
2190 "param_reqs": param_reqs,
2191 }
2192 return info_dict
2193
2194 @staticmethod
2195 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2196 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2197 param_reqs = {"rank": None, "dtype": None, "shape": None}
2198 error_result = False
2199 error_reason = "Input list does not match body graph output"
2200
2201 if check:
2202 basicBlocks = kwargs["basicBlocks"]
2203 while_block = basicBlocks[0]
2204 while_inputs = while_block.inputs
2205 while_tens = while_block.tensors
2206 body_block = basicBlocks[2]
2207 body_outputs = body_block.outputs
2208 body_tens = body_block.tensors
2209 if (
2210 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2211 ) or (
2212 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2213 ):
2214 error_result = True
2215 info_dict = {
2216 "error_name": error_name,
2217 "error_result": error_result,
2218 "error_reason": error_reason,
2219 "param_reqs": param_reqs,
2220 }
2221 return info_dict
2222
2223 @staticmethod
2224 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2225 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2226 param_reqs = {"rank": None, "dtype": None, "shape": None}
2227 error_result = False
2228 error_reason = "Cond graph output is not a match list of booleans"
2229
2230 if check:
2231 basicBlocks = kwargs["basicBlocks"]
2232 cond_block = basicBlocks[1]
2233 cond_outputs = cond_block.outputs
2234 cond_tens = cond_block.tensors
2235 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2236 error_result = True
2237
2238 info_dict = {
2239 "error_name": error_name,
2240 "error_result": error_result,
2241 "error_reason": error_reason,
2242 "param_reqs": param_reqs,
2243 }
2244 return info_dict
2245
2246
2247class TosaInvalidValidator:
2248 @staticmethod
2249 def ivWrongDataTypeOrModeResize(**kwargs):
2250 input_dtype = kwargs["input_dtype"]
2251 args = kwargs["args"]
2252 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002253 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002254
2255 if mode == ResizeMode.BILINEAR:
2256 # Invalid output data type / Invalid input datatype
2257 return (
2258 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002259 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
2260 and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002261 )
2262 elif mode == ResizeMode.NEAREST:
2263 # Invalid output data type / Invalid input datatype
2264 return (input_dtype != output_dtype) or (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002265 input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002266 )
2267 else:
2268 # Invalid resize mode
2269 return True
2270
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when the op's spatial parameters are invalid.

        Used to discard generated parameter sets: pooling/convolution
        combinations whose computed output height/width would be < 1, or a
        transpose_conv2d output shape that no padding option can produce.

        Args:
            kwargs: must contain "opName" (str), "shapeList" (list of input
                shapes; entry 0 is the data input, entry 1 the filter where
                applicable) and "args" (op attribute list; args[0] is strides
                and args[1] is padding for all branches below, args[2] varies
                per op).

        Returns:
            True if the parameter combination is invalid for opName.
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]
        strides = args[0]
        padding = args[1]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            # Standard pooling output-size formula; padding is laid out as
            # [top, bottom, left, right] and input_shape as NHWC.
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[2]
            filter_shape = inputShapes[1]
            # Filter is OHWI: the spatial kernel dims are the middle two.
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # Accept the requested output shape if ANY of the conventional
            # padding schemes (FULL / SAME / VALID) would produce it.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters store the kernel dims first; other conv
            # filters store them between the output- and input-channel dims.
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            # Dilated-convolution output-size formula per spatial dimension;
            # padding is flattened as [d0_before, d0_after, d1_before, ...].
            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        # NOTE(review): assert is stripped under `python -O`; acceptable here
        # because this is test-generator tooling, not production validation.
        assert False, f"Unrecognized Op: {opName}"
2364
2365 @staticmethod
2366 def ivNonPositiveOutputShape(**kwargs):
2367 args = kwargs["args"]
TatWai Chong24594f52022-06-08 00:48:04 -07002368 output_shape = args[2]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002369 if output_shape[1] <= 0 or output_shape[2] <= 0:
2370 # Negative output shape
2371 return True
2372 return False