blob: e4e60b77ff5627e16d9b558cb84025520a04f7fc [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005from generator.tosa_utils import product
6from generator.tosa_utils import usableDTypes
7from generator.tosa_utils import valueToName
8from tosa.DType import DType
9from tosa.Op import Op
10from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011
Matthew Haddone86fd342021-09-07 16:12:21 +010012
class ErrorIf(object):
    """Names of the TOSA ERROR_IF conditions that the generator can provoke.

    Each attribute is a string constant equal to its own name; these are
    used both as test-selection keys and as the validator names returned
    in the validators' info dicts.
    """

    # RESIZE parameter conditions
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    # Generic type / list / rank conditions
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    # Quantization zero-point conditions
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    # Axis / reduction conditions
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    # Pooling / convolution parameter conditions
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    # RESCALE conditions
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    # SLICE / GATHER / SCATTER conditions
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    # CLAMP / CONCAT conditions
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    # Control-flow (COND_IF / WHILE_LOOP) conditions
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    # UINT16 zero-point conditions
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010076
77
class TosaErrorIfArgGen:
    """Helpers that deliberately corrupt operator arguments so that a
    specific TOSA ERROR_IF condition is provoked by the generated test.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Mutate RESIZE arguments to trigger the requested error.

        Args:
            testGen: generator providing randInt/rng.choice helpers
            error_name: the ErrorIf condition to provoke
            mode: ResizeMode value
            dtype: input data type
            shapeList: operand shapes (unused here, kept for the common
                eiXxxErrorIf call signature)
            outputDType: current output data type
            scale, offset, border: RESIZE attribute lists; scale/offset/border
                are modified in place as well as returned
        Returns:
            (scale, offset, border, outputDType) with the value relevant to
            error_name made invalid.
        """
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # scale_n values live at even indices; maximum legal value is 1 << 11
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # scale_d values live at odd indices; must be < 16 * scale_n
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            # legal offset must be < 16 * scale_n
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            # legal offset must be >= -scale_n
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            # legal border must be < scale_n
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            # legal border must be >= -16 * scale_n
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output type that is illegal for this mode/input type
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with the component relevant to
        error_name made invalid, or (None, None, None) when error_name
        is not a pooling-parameter condition."""
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when output_dtype is not legal for a RESCALE with
        the given input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensors for ERROR_IF checks.

        Note: error_name is compared against plain strings here; the values
        match the ErrorIf constants of the same name.
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            # shrink every dimension (never below 1) until the volume fits
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # Bug fix: was missing @staticmethod, unlike every sibling helper
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) for SLICE made invalid per error_name; the
        inputs are returned unchanged for unrelated error names."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                # Bug fix: build new lists rather than appending to the
                # caller's start/size lists in place (aliasing side effect)
                newStart = start + [1]
                newSize = size + [1]
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of output types that are invalid for a CAST
        from input_dtype."""
        if input_dtype in [DType.BOOL, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # Bug fix: was `assert True, ...` which can never fire and then
            # fell through to return an unbound name - raise explicitly instead
            raise ValueError(f"input_dtype ({input_dtype}) not supported")
        return outputDType
295
296
297class TosaErrorValidator:
298 @staticmethod
299 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
300 """Check ERROR_IF statements are caught and set the expected result.
301
302 Args:
303 serializer: the serializer to set the expected result in
304 validator_fcns: a sequence of validator functions to verify the result
305 error_name: the name of the ERROR_IF condition to check for
306 kwargs: keyword arguments for the validator functions
307 Returns:
308 True if the result matches the expected result; otherwise False
309 """
310 overall_result = True
311 for val_fcn in validator_fcns:
312 val_result = val_fcn(True, **kwargs)
313 validator_name = val_result["error_name"]
314 error_result = val_result["error_result"]
315 error_reason = val_result["error_reason"]
316
317 # expect an error IFF the error_name and validator_name match
318 expected_result = error_result == (error_name == validator_name)
319 overall_result &= expected_result
320
321 if expected_result and error_result:
322 serializer.setExpectedReturnCode(2, True, desc=error_reason)
323 elif error_result: # and not expected_result
324 print(
325 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
326 f" Expected: {error_name}, Got: {validator_name}"
327 )
328 elif not expected_result: # and not error_result
329 print(
330 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
331 f" Expected: {error_name}"
332 )
333
334 if not expected_result:
335 for k, v in sorted(kwargs.items()):
336 if k != "op":
337 if k.endswith("dtype"):
338 v = valueToName(DType, v)
339 print(f" {k} = {v}")
340
341 return overall_result
342
343 @staticmethod
344 def evWrongInputType(check=False, **kwargs):
345 error_result = False
346
347 # Find the unsupported input data types
348 op = kwargs["op"]
349 input_dtypes = op["types"]
350 allowed_input_dtypes = {
351 t[0] if isinstance(t, list) else t for t in input_dtypes
352 }
353 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
354
355 if op["op"] == Op.CLAMP:
356 wrong_input_dtypes.remove(DType.INT48)
357
358 if check:
359 input_dtype = kwargs["input_dtype"]
360 if input_dtype not in allowed_input_dtypes:
361 error_result = True
362
363 info_dict = {
364 "error_name": ErrorIf.WrongInputType,
365 "error_result": error_result,
366 "error_reason": "Input data type not supported for this operator",
367 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
368 }
369 return info_dict
370
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for the WrongOutputType ERROR_IF condition.

        Checks the provided output_dtype against the output type each
        operator requires for the given input_dtype (and RESIZE mode).
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # RESIZE output type depends on both mode and input type
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE legality is shared with the argument generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                # accumulator output: INT8 -> INT32, INT16 -> INT48
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always produces INT32 indices
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # integer MUL outputs INT32; float MUL outputs FLOAT
                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op["op"] == Op.TABLE:
                # TABLE: INT8 -> INT8 lookup, INT16 -> INT32 lookup
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # comparison ops always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # each input type may cast to any type except itself and INT48
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                # accumulator output; note `and` binds tighter than `or` so
                # this reads as three (input, wrong output) pairs
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # default: output type must match the input type
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
501
502 @staticmethod
503 def evWrongRank(check=False, **kwargs):
504 all_ranks = (1, 2, 3, 4, 5)
505
506 # Make a list of incorrect ranks
507 assert "op" in kwargs
508 op = kwargs["op"]
509 rmin, rmax = op["rank"]
510 rank_range = range(rmin, rmax + 1)
511 incorrect_ranks = list(set(all_ranks) - set(rank_range))
512 # Remove small incorrect ranks to avoid index errors
513 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
514 # Set minimum incorrect rank to 3 to avoid index error
515 if op["op"] in [Op.RESIZE]:
516 incorrect_ranks = [3, 5]
517 elif op["op"] in [Op.TRANSPOSE]:
518 incorrect_ranks = [7, 8]
519 elif op["op"] in [Op.CONV3D]:
520 incorrect_ranks = [6, 7]
521
522 error_name = ErrorIf.WrongRank
523 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
524 error_result = False
525 error_reason = "Rank not supported for this operator"
526
527 if check:
528 input_shape = kwargs["input_shape"]
529
530 if (
531 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
532 and len(input_shape) != 4
533 ):
534 error_result = True
535 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
536 error_result = True
537 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
538 error_result = True
539 else:
540 if len(input_shape) not in rank_range:
541 error_result = True
542
543 info_dict = {
544 "error_name": error_name,
545 "error_result": error_result,
546 "error_reason": error_reason,
547 "param_reqs": param_reqs,
548 }
549 return info_dict
550
551 @staticmethod
552 def evWrongInputList(check=False, **kwargs):
553 error_name = ErrorIf.WrongInputList
554 param_reqs = {"rank": None, "dtype": None, "shape": None}
555 error_result = False
556 error_reason = "Op input list does not match expected input"
557
558 if check:
559 op = kwargs["op"]
560 input_list = kwargs["input_list"]
561 num_operands = kwargs["num_operands"]
562 if op["op"] in [Op.SCATTER, Op.GATHER]:
563 # SCATTER/GATHER add an indices input tensor in their build functions
564 num_operands += 1
565 if len(input_list) != num_operands:
566 error_result = True
567
568 info_dict = {
569 "error_name": error_name,
570 "error_result": error_result,
571 "error_reason": error_reason,
572 "param_reqs": param_reqs,
573 }
574 return info_dict
575
576 @staticmethod
577 def evWrongOutputList(check=False, **kwargs):
578 error_name = ErrorIf.WrongOutputList
579 param_reqs = {"rank": None, "dtype": None, "shape": None}
580 error_result = False
581 error_reason = "Op output list does not match expected output"
582
583 if check:
584 output_list = kwargs["output_list"]
585 # Note this will be incorrect if an operator returns more than one output
586 if len(output_list) != 1:
587 error_result = True
588
589 info_dict = {
590 "error_name": error_name,
591 "error_result": error_result,
592 "error_reason": error_reason,
593 "param_reqs": param_reqs,
594 }
595 return info_dict
596
597 @staticmethod
598 def evMaxDimExceeded(check=False, **kwargs):
599 error_name = ErrorIf.MaxDimExceeded
600 param_reqs = {
601 "rank": [4, 4],
602 "dtype": [DType.INT8],
603 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
604 }
605 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100606 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100607
608 if check:
609 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100610 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100611 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100612 (input_shape[1] >= MAX_RESIZE_DIMENSION)
613 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
614 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
615 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100616 ):
617 error_result = True
618
619 info_dict = {
620 "error_name": error_name,
621 "error_result": error_result,
622 "error_reason": error_reason,
623 "param_reqs": param_reqs,
624 }
625 return info_dict
626
627 @staticmethod
628 def evBatchMismatch(check=False, **kwargs):
629 error_name = ErrorIf.BatchMismatch
630 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
631 error_result = False
632 error_reason = "Input batch size not equal to output batch size"
633
634 assert "op" in kwargs
635 op = kwargs["op"]
636 rmin, rmax = op["rank"]
637 rank_range = range(rmin, rmax + 1)
638
639 if check:
640 input_shape = kwargs["input_shape"]
641 output_shape = kwargs[
642 "result_tensor"
643 ].shape # Note this is just (N, OH, OW, C)
644
645 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
646 error_result = True
647
648 info_dict = {
649 "error_name": error_name,
650 "error_result": error_result,
651 "error_reason": error_reason,
652 "param_reqs": param_reqs,
653 }
654 return info_dict
655
656 @staticmethod
657 def evChannelMismatch(check=False, **kwargs):
658 error_name = ErrorIf.ChannelMismatch
659 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
660 error_result = False
661 error_reason = "Input channel size not equal to output channel size"
662
663 assert "op" in kwargs
664 op = kwargs["op"]
665 rmin, rmax = op["rank"]
666 rank_range = range(rmin, rmax + 1)
667
668 if check:
669 input_shape = kwargs["input_shape"]
670 output_shape = kwargs[
671 "result_tensor"
672 ].shape # Note this is just (N, OH, OW, C)
673 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
674 error_result = True
675
676 info_dict = {
677 "error_name": error_name,
678 "error_result": error_result,
679 "error_reason": error_reason,
680 "param_reqs": param_reqs,
681 }
682 return info_dict
683
684 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100685 def evScaleSmallerEqualZero(check=False, **kwargs):
686 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100687 param_reqs = {"rank": None, "dtype": None, "shape": None}
688 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100689 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100690
691 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100692 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100693
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100694 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100695 error_result = True
696
697 info_dict = {
698 "error_name": error_name,
699 "error_result": error_result,
700 "error_reason": error_reason,
701 "param_reqs": param_reqs,
702 }
703 return info_dict
704
705 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100706 def evScaleNLargerMax(check=False, **kwargs):
707 error_name = ErrorIf.ScaleNLargerMax
708 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100709 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100710 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100711
712 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100713 scale = kwargs["scale"]
714
715 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
716 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100717
718 info_dict = {
719 "error_name": error_name,
720 "error_result": error_result,
721 "error_reason": error_reason,
722 "param_reqs": param_reqs,
723 }
724 return info_dict
725
726 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100727 def evScaleDLargerMax(check=False, **kwargs):
728 error_name = ErrorIf.ScaleDLargerMax
729 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100730 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100731 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100732
733 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100734 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100735
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100736 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
737 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100738 ):
739 error_result = True
740
741 info_dict = {
742 "error_name": error_name,
743 "error_result": error_result,
744 "error_reason": error_reason,
745 "param_reqs": param_reqs,
746 }
747 return info_dict
748
749 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100750 def evOffsetSmallerMin(check=False, **kwargs):
751 error_name = ErrorIf.OffsetSmallerMin
752 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100753 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100754 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100755
756 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100757 scale = kwargs["scale"]
758 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100759
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100760 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100761 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100762 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100763 error_result = True
764
765 info_dict = {
766 "error_name": error_name,
767 "error_result": error_result,
768 "error_reason": error_reason,
769 "param_reqs": param_reqs,
770 }
771 return info_dict
772
773 @staticmethod
774 def evOffsetLargerEqualMax(check=False, **kwargs):
775 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100776 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100777 error_result = False
778 error_reason = "Offset value larger than or equal to maximum value"
779
780 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100781 scale = kwargs["scale"]
782 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100783
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100784 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
785 error_result = True
786 elif (
787 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
788 ):
789 error_result = True
790
791 info_dict = {
792 "error_name": error_name,
793 "error_result": error_result,
794 "error_reason": error_reason,
795 "param_reqs": param_reqs,
796 }
797 return info_dict
798
799 @staticmethod
800 def evBorderSmallerMin(check=False, **kwargs):
801 error_name = ErrorIf.BorderSmallerMin
802 param_reqs = {"rank": None, "dtype": None, "shape": None}
803 error_result = False
804 error_reason = "Border value smaller than minimum value"
805
806 if check:
807 scale = kwargs["scale"]
808 border = kwargs["border"]
809
810 if (
811 scale[0] > 0
812 and scale[0] <= (1 << 11)
813 and (border[0] < (-16 * scale[0]))
814 ):
815 error_result = True
816 elif (
817 scale[2] > 0
818 and scale[2] <= (1 << 11)
819 and (border[1] < (-16 * scale[2]))
820 ):
821 error_result = True
822
823 info_dict = {
824 "error_name": error_name,
825 "error_result": error_result,
826 "error_reason": error_reason,
827 "param_reqs": param_reqs,
828 }
829 return info_dict
830
831 @staticmethod
832 def evBorderLargerEqualMax(check=False, **kwargs):
833 error_name = ErrorIf.BorderLargerEqualMax
834 param_reqs = {"rank": None, "dtype": None, "shape": None}
835 error_result = False
836 error_reason = "Border value larger than or equal to maximum value"
837
838 if check:
839 scale = kwargs["scale"]
840 border = kwargs["border"]
841
842 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
843 error_result = True
844 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
845 error_result = True
846
847 info_dict = {
848 "error_name": error_name,
849 "error_result": error_result,
850 "error_reason": error_reason,
851 "param_reqs": param_reqs,
852 }
853 return info_dict
854
855 @staticmethod
856 def checkResizeParams(scale, offset, border):
857 return (
858 min(scale) > 0
859 and max(scale[0], scale[2]) <= (1 << 11)
860 and scale[1] < 16 * scale[0]
861 and scale[3] < 16 * scale[2]
862 and offset[0] >= -scale[0]
863 and offset[1] >= -scale[2]
864 and offset[0] < 16 * scale[0]
865 and offset[1] < 16 * scale[2]
866 and border[0] >= -16 * scale[0]
867 and border[1] >= -16 * scale[2]
868 and border[0] < scale[0]
869 and border[1] < scale[2]
870 )
871
872 @staticmethod
873 def evResizeOutputShapeMismatch(check=False, **kwargs):
874 error_name = ErrorIf.ResizeOutputShapeMismatch
875 param_reqs = {"rank": None, "dtype": None, "shape": None}
876 error_result = False
877 error_reason = (
878 "Mismatch between output shape provided and expected output shape"
879 )
880
881 if check:
882 input_shape = kwargs["input_shape"]
883 output_shape = kwargs["output_shape"]
884 scale = kwargs["scale"]
885 offset = kwargs["offset"]
886 border = kwargs["border"]
887
888 # Ensure parameters are valid
889 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
890
891 if (
892 params_valid
893 and max(output_shape) < MAX_RESIZE_DIMENSION
894 and max(input_shape) < MAX_RESIZE_DIMENSION
895 ):
896 output_y = (
897 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
898 ) // scale[1] + 1
899 output_x = (
900 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
901 ) // scale[3] + 1
902
903 if [output_y, output_x] != output_shape[1:-1]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100904 error_result = True
905
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100906 info_dict = {
907 "error_name": error_name,
908 "error_result": error_result,
909 "error_reason": error_reason,
910 "param_reqs": param_reqs,
911 }
912 return info_dict
913
914 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100915 def evResizeOutputShapeNonInteger(check=False, **kwargs):
916 error_name = ErrorIf.ResizeOutputShapeNonInteger
917 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100918 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100919 error_reason = "Parameters do not yield exact integer output dimensions"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920
921 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100922 input_shape = kwargs["input_shape"]
923 scale = kwargs["scale"]
924 offset = kwargs["offset"]
925 border = kwargs["border"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100926
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100927 # Ensure parameters are valid
928 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100929
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100930 if params_valid:
931 remainder_y = (
932 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
933 ) % scale[1]
934 remainder_x = (
935 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
936 ) % scale[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100937
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100938 if max(remainder_y, remainder_x) > 0:
939 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100940
941 info_dict = {
942 "error_name": error_name,
943 "error_result": error_result,
944 "error_reason": error_reason,
945 "param_reqs": param_reqs,
946 }
947 return info_dict
948
949 @staticmethod
950 def evRankMismatch(check=False, **kwargs):
951 error_name = ErrorIf.RankMismatch
952 param_reqs = {"rank": None, "dtype": None, "shape": None}
953 error_result = False
954 error_reason = "Input Rank does not match output rank"
955
956 if check:
957 input1_shape = kwargs["input1"].shape
958 input2_shape = kwargs["input2"].shape
959 # In case of SELECT op
960 input3_shape = (
961 kwargs["input3"].shape if "input3" in kwargs else input2_shape
962 )
963 output_shape = kwargs["result_tensor"].shape
964 if (
965 (len(input1_shape) != len(output_shape))
966 or (len(input2_shape) != len(output_shape))
967 or (len(input3_shape) != len(output_shape))
968 ):
969 error_result = True
970
971 info_dict = {
972 "error_name": error_name,
973 "error_result": error_result,
974 "error_reason": error_reason,
975 "param_reqs": param_reqs,
976 }
977 return info_dict
978
979 @staticmethod
980 def evDimensionMismatch(check=False, **kwargs):
981 error_name = ErrorIf.DimensionMismatch
982 param_reqs = {"rank": None, "dtype": None, "shape": None}
983 error_result = False
984 error_reason = "Input Dimensions do not match output"
985
986 if check:
987 input1_shape = kwargs["input1"].shape
988 input2_shape = kwargs["input2"].shape
989 # In case of SELECT op
990 input3_shape = (
991 kwargs["input3"].shape if "input3" in kwargs else input2_shape
992 )
993 output_shape = kwargs["result_tensor"].shape
994 for i in range(
995 min(len(input1_shape), len(input2_shape), len(input3_shape))
996 ):
997 if (
998 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
999 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
1000 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
1001 ):
1002 error_result = True
1003
1004 info_dict = {
1005 "error_name": error_name,
1006 "error_result": error_result,
1007 "error_reason": error_reason,
1008 "param_reqs": param_reqs,
1009 }
1010 return info_dict
1011
    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        # qinfo is indexable; no bounds handling here — callers pass 0 or 1.
        return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001019
1020 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 def evInputZeroPointNotZero(check=False, **kwargs):
1022 op = kwargs["op"]
1023 error_result = False
1024
1025 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001026 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027
1028 # This does not apply to quantizable types
1029 inputDtypes = [
1030 dtype
1031 for dtype in op["types"]
1032 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1033 or (not isinstance(dtype, list) and dtype not in qTypes)
1034 ]
1035
1036 if check:
1037 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001038 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001039 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001040 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001041 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001042 (kwargs["input_dtype"], input_zero_point),
1043 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044 ):
1045 if dtype not in qTypes and zp != 0:
1046 error_result = True
1047 break
1048 else:
1049 error_result = input_dtype not in qTypes and input_zero_point != 0
1050
1051 info_dict = {
1052 "error_name": ErrorIf.InputZeroPointNotZero,
1053 "error_result": error_result,
1054 "error_reason": "Input DType not INT8 and zero point not 0",
1055 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1056 }
1057 return info_dict
1058
1059 @staticmethod
1060 def evWeightZeroPointNotZero(check=False, **kwargs):
1061 op = kwargs["op"]
1062
1063 # exclude inputs with INT8 weights
1064 inputDtypes = [
1065 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1066 ]
1067
1068 error_name = ErrorIf.WeightZeroPointNotZero
1069 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1070 error_result = False
1071 error_reason = "Weight DType not INT8 and zero point not 0"
1072
1073 if check:
1074 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001075 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001076 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1077 error_result = True
1078
1079 info_dict = {
1080 "error_name": error_name,
1081 "error_result": error_result,
1082 "error_reason": error_reason,
1083 "param_reqs": param_reqs,
1084 }
1085 return info_dict
1086
1087 @staticmethod
1088 def evOutputZeroPointNotZero(check=False, **kwargs):
1089 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001090 inputDtypes = [
1091 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1092 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001093
1094 error_name = ErrorIf.OutputZeroPointNotZero
1095 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1096 error_result = False
1097 error_reason = "Output DType not INT8 and zero point not 0"
1098
1099 if check:
1100 input_dtype = kwargs["input_dtype"]
1101 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001102 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001103 if op["op"] == Op.AVG_POOL2D:
1104 if input_dtype != DType.INT8 and output_zero_point != 0:
1105 error_result = True
1106 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001107 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1108 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001109 ):
1110 error_result = True
1111
1112 info_dict = {
1113 "error_name": error_name,
1114 "error_result": error_result,
1115 "error_reason": error_reason,
1116 "param_reqs": param_reqs,
1117 }
1118 return info_dict
1119
1120 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001121 def evU16InputZeroPointNotValid(check=False, **kwargs):
1122 error_name = ErrorIf.U16InputZeroPointNotValid
1123 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1124 error_result = False
1125 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1126
1127 if check:
1128 input_dtype = kwargs["input_dtype"]
1129 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1130 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1131 0,
1132 32768,
1133 ]
1134
1135 info_dict = {
1136 "error_name": error_name,
1137 "error_result": error_result,
1138 "error_reason": error_reason,
1139 "param_reqs": param_reqs,
1140 }
1141 return info_dict
1142
1143 @staticmethod
1144 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1145 error_name = ErrorIf.U16OutputZeroPointNotValid
1146 param_reqs = {"rank": None, "dtype": None, "shape": None}
1147 error_result = False
1148 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1149
1150 if check:
1151 output_dtype = kwargs["output_dtype"]
1152 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1153
1154 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1155 0,
1156 32768,
1157 ]
1158
1159 info_dict = {
1160 "error_name": error_name,
1161 "error_result": error_result,
1162 "error_reason": error_reason,
1163 "param_reqs": param_reqs,
1164 }
1165 return info_dict
1166
1167 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001168 def evAxisSmallerZero(check=False, **kwargs):
1169 error_name = ErrorIf.AxisSmallerZero
1170 param_reqs = {"rank": None, "dtype": None, "shape": None}
1171 error_result = False
1172 error_reason = "Axis smaller than zero"
1173
1174 if check:
1175 axis = kwargs["axis"]
1176 if axis < 0:
1177 error_result = True
1178
1179 info_dict = {
1180 "error_name": error_name,
1181 "error_result": error_result,
1182 "error_reason": error_reason,
1183 "param_reqs": param_reqs,
1184 }
1185 return info_dict
1186
1187 @staticmethod
1188 def evAxisLargerRank(check=False, **kwargs):
1189 error_name = ErrorIf.AxisLargerRank
1190 param_reqs = {"rank": None, "dtype": None, "shape": None}
1191 error_result = False
1192 error_reason = "Axis larger than rank"
1193
1194 if check:
1195 axis = kwargs["axis"]
1196 shape = kwargs["input_shape"]
1197 if axis > len(shape):
1198 error_result = True
1199
1200 info_dict = {
1201 "error_name": error_name,
1202 "error_result": error_result,
1203 "error_reason": error_reason,
1204 "param_reqs": param_reqs,
1205 }
1206 return info_dict
1207
1208 @staticmethod
1209 def evShapeOfAxisNotOne(check=False, **kwargs):
1210 error_name = ErrorIf.ShapeOfAxisNotOne
1211 param_reqs = {"rank": None, "dtype": None, "shape": None}
1212 error_result = False
1213 error_reason = "shape[axis] is not equal to 1"
1214
1215 if check:
1216 axis = kwargs["axis"]
1217 shape = kwargs["output_shape"]
1218 if (0 <= axis < len(shape)) and shape[axis] != 1:
1219 error_result = True
1220
1221 info_dict = {
1222 "error_name": error_name,
1223 "error_result": error_result,
1224 "error_reason": error_reason,
1225 "param_reqs": param_reqs,
1226 }
1227 return info_dict
1228
1229 @staticmethod
1230 def evPadSmallerZero(check=False, **kwargs):
1231 error_name = ErrorIf.PadSmallerZero
1232 param_reqs = {"rank": None, "dtype": None, "shape": None}
1233 error_result = False
1234 error_reason = "At least one pad is smaller than zero"
1235
1236 if check:
1237 op = kwargs["op"]
1238 pad = kwargs["pad"]
1239 if op["op"] == Op.PAD:
1240 for padding in pad:
1241 if min(padding) < 0:
1242 error_result = True
1243 else:
1244 if min(pad) < 0:
1245 error_result = True
1246
1247 info_dict = {
1248 "error_name": error_name,
1249 "error_result": error_result,
1250 "error_reason": error_reason,
1251 "param_reqs": param_reqs,
1252 }
1253 return info_dict
1254
1255 @staticmethod
1256 def evPadLargerEqualKernel(check=False, **kwargs):
1257 error_name = ErrorIf.PadLargerEqualKernel
1258 param_reqs = {"rank": None, "dtype": None, "shape": None}
1259 error_result = False
1260 error_reason = "At least one pad is larger than kernel dimension"
1261
1262 if check:
1263 pad = kwargs["pad"]
1264 kernel = kwargs["kernel"]
1265 if min(pad) > 0 and min(kernel) > 1:
1266 if (
1267 pad[0] >= kernel[0]
1268 or pad[1] >= kernel[0]
1269 or pad[2] >= kernel[1]
1270 or pad[3] >= kernel[1]
1271 ):
1272 error_result = True
1273
1274 info_dict = {
1275 "error_name": error_name,
1276 "error_result": error_result,
1277 "error_reason": error_reason,
1278 "param_reqs": param_reqs,
1279 }
1280 return info_dict
1281
1282 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001283 def evPadOutputShapeMismatch(check=False, **kwargs):
1284 error_name = ErrorIf.PadOutputShapeMismatch
1285 param_reqs = {"rank": None, "dtype": None, "shape": None}
1286 error_result = False
1287 error_reason = "Pad output shape mismatch for requested padding"
1288
1289 if check:
1290 pad = kwargs["pad"]
1291 input_shape = kwargs["input_shape"]
1292 output_shape = kwargs["output_shape"]
1293 for dim, padding in enumerate(pad):
1294 expected_size = input_shape[dim] + padding[0] + padding[1]
1295 if expected_size != output_shape[dim]:
1296 error_result = True
1297
1298 info_dict = {
1299 "error_name": error_name,
1300 "error_result": error_result,
1301 "error_reason": error_reason,
1302 "param_reqs": param_reqs,
1303 }
1304 return info_dict
1305
1306 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001307 def checkPoolingParams(kernel, stride, pad):
1308 return (
1309 min(kernel) >= 1
1310 and min(stride) >= 1
1311 and min(pad) >= 0
1312 and not (
1313 pad[0] >= kernel[0]
1314 or pad[1] >= kernel[0]
1315 or pad[2] >= kernel[1]
1316 or pad[3] >= kernel[1]
1317 )
1318 )
1319
1320 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001321 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1322 error_name = ErrorIf.PoolingOutputShapeMismatch
1323 param_reqs = {"rank": None, "dtype": None, "shape": None}
1324 error_result = False
1325 error_reason = (
1326 "Mismatch between output shape provided and expected output shape"
1327 )
1328
1329 if check:
1330 pad = kwargs["pad"]
1331 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1332
1333 kernel = kwargs["kernel"]
1334 kernel_y, kernel_x = kernel[0], kernel[1]
1335
1336 input_shape = kwargs["input_shape"]
1337 IH, IW = input_shape[1], input_shape[2]
1338
1339 output_shape = kwargs["output_shape"]
1340 OH, OW = output_shape[1], output_shape[2]
1341
1342 stride = kwargs["stride"]
1343 stride_y, stride_x = stride[0], stride[1]
1344
1345 # calculate correct height, width dimensions
1346 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001347 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1348 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001349
1350 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001351 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001352
1353 if params_valid and (OH != y_correct or OW != x_correct):
1354 error_result = True
1355
1356 info_dict = {
1357 "error_name": error_name,
1358 "error_result": error_result,
1359 "error_reason": error_reason,
1360 "param_reqs": param_reqs,
1361 }
1362 return info_dict
1363
1364 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001365 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1366 error_name = ErrorIf.PoolingOutputShapeNonInteger
1367 param_reqs = {"rank": None, "dtype": None, "shape": None}
1368 error_result = False
1369 error_reason = "Parameters do not yield exact integer output dimensions"
1370
1371 if check:
1372 pad = kwargs["pad"]
1373 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1374
1375 kernel = kwargs["kernel"]
1376 kernel_y, kernel_x = kernel[0], kernel[1]
1377
1378 input_shape = kwargs["input_shape"]
1379 IH, IW = input_shape[1], input_shape[2]
1380
1381 stride = kwargs["stride"]
1382 stride_y, stride_x = stride[0], stride[1]
1383
1384 # calculate remainder of height, width dimensions
1385 if stride_x != 0 and stride_y != 0:
1386 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1387 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1388
1389 # ensure parameters are valid
1390 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1391 if params_valid and (y_remainder != 0 or x_remainder != 0):
1392 error_result = True
1393
1394 info_dict = {
1395 "error_name": error_name,
1396 "error_result": error_result,
1397 "error_reason": error_reason,
1398 "param_reqs": param_reqs,
1399 }
1400 return info_dict
1401
1402 @staticmethod
1403 def checkConvParams(weight_shape, stride, pad, dilation):
1404 return (
1405 # Check kernel sizes
1406 min(weight_shape[1:-1]) >= 1
1407 and min(stride) >= 1
1408 and min(pad) >= 0
1409 and (dilation is None or min(dilation) >= 1)
1410 )
1411
1412 @staticmethod
1413 def evConvOutputShapeMismatch(check=False, **kwargs):
1414 error_name = ErrorIf.ConvOutputShapeMismatch
1415 param_reqs = {"rank": None, "dtype": None, "shape": None}
1416 error_result = False
1417 error_reason = (
1418 "Mismatch between output shape provided and expected output shape"
1419 )
1420
1421 if check:
1422 op = kwargs["op"]
1423 pad = kwargs["pad"]
1424 weight_shape = kwargs["weight_shape"]
1425 input_shape = kwargs["input_shape"]
1426 output_shape = kwargs["output_shape"]
1427 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1428 stride = kwargs["stride"]
1429
1430 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1431
1432 # calculate correct dimensions
1433 dims_correct = []
1434 if min(stride) > 0:
1435 for index in range(len(stride)):
1436 pad_offset = index * 2
1437 if op["op"] == Op.TRANSPOSE_CONV2D:
1438 dims_correct.append(
1439 (input_shape[index + 1] - 1) * stride[index]
1440 - pad[pad_offset]
1441 - pad[pad_offset + 1]
1442 + weight_shape[index + kernel_offset]
1443 )
1444 else:
1445 dims_correct.append(
1446 (
1447 input_shape[index + 1]
1448 - 1
1449 + pad[pad_offset]
1450 + pad[pad_offset + 1]
1451 - (weight_shape[index + kernel_offset] - 1)
1452 * dilation[index]
1453 )
1454 // stride[index]
1455 + 1
1456 )
1457
1458 # ensure parameters are valid
1459 params_valid = TosaErrorValidator.checkConvParams(
1460 weight_shape, stride, pad, dilation
1461 )
1462
1463 if params_valid and output_shape[1:-1] != dims_correct:
1464 error_result = True
1465
1466 info_dict = {
1467 "error_name": error_name,
1468 "error_result": error_result,
1469 "error_reason": error_reason,
1470 "param_reqs": param_reqs,
1471 }
1472 return info_dict
1473
1474 @staticmethod
1475 def evConvOutputShapeNonInteger(check=False, **kwargs):
1476 error_name = ErrorIf.ConvOutputShapeNonInteger
1477 param_reqs = {"rank": None, "dtype": None, "shape": None}
1478 error_result = False
1479 error_reason = "Parameters do not yield exact integer output dimensions"
1480
1481 if check:
1482 op = kwargs["op"]
1483 pad = kwargs["pad"]
1484 weight_shape = kwargs["weight_shape"]
1485 input_shape = kwargs["input_shape"]
1486 dilation = kwargs["dilation"]
1487 stride = kwargs["stride"]
1488
1489 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1490
1491 # calculate correct height, width dimensions
1492 remainders = []
1493 if min(stride) > 0:
1494 for index in range(len(stride)):
1495 pad_offset = index * 2
1496 remainders.append(
1497 (
1498 input_shape[index + 1]
1499 - 1
1500 + pad[pad_offset]
1501 + pad[pad_offset + 1]
1502 - (weight_shape[index + kernel_offset] - 1)
1503 * dilation[index]
1504 )
1505 % stride[index]
1506 )
1507
1508 # ensure parameters are valid
1509 params_valid = TosaErrorValidator.checkConvParams(
1510 weight_shape, stride, pad, dilation
1511 )
1512 if params_valid and max(remainders) > 0:
1513 error_result = True
1514
1515 info_dict = {
1516 "error_name": error_name,
1517 "error_result": error_result,
1518 "error_reason": error_reason,
1519 "param_reqs": param_reqs,
1520 }
1521 return info_dict
1522
1523 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001524 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1525 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1526 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1527 error_result = False
1528 error_reason = (
1529 "Mismatch between output shape provided and expected output shape"
1530 )
1531
1532 if check:
1533 output_shape = kwargs["output_shape"]
1534 input_shape = kwargs["input_shape"]
1535 axis = kwargs["axis"]
1536
1537 dimension_match = True
1538 axis_shift = 0
1539
1540 # Check that rank is correct before trying to check dimensions
1541 if (len(input_shape) - 1) == len(output_shape):
1542 for i in range(len(input_shape)):
1543 if i == axis:
1544 axis_shift = 1
1545 continue
1546 if input_shape[i] != output_shape[i - axis_shift]:
1547 dimension_match = False
1548
1549 if not dimension_match:
1550 error_result = True
1551
1552 info_dict = {
1553 "error_name": error_name,
1554 "error_result": error_result,
1555 "error_reason": error_reason,
1556 "param_reqs": param_reqs,
1557 }
1558 return info_dict
1559
1560 @staticmethod
1561 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1562 error_name = ErrorIf.ArgmaxOutputRankMismatch
1563 param_reqs = {"rank": None, "dtype": None, "shape": None}
1564 error_result = False
1565 error_reason = (
1566 "Mismatch between output shape provided and expected output shape"
1567 )
1568
1569 if check:
1570 output_shape = kwargs["output_shape"]
1571 input_shape = kwargs["input_shape"]
1572 axis = kwargs["axis"]
1573 valid_params = axis >= 0 and axis < len(input_shape)
1574
1575 if valid_params and (len(input_shape) - 1) != len(output_shape):
1576 error_result = True
1577
1578 info_dict = {
1579 "error_name": error_name,
1580 "error_result": error_result,
1581 "error_reason": error_reason,
1582 "param_reqs": param_reqs,
1583 }
1584 return info_dict
1585
1586 @staticmethod
1587 def evKernelSmallerOne(check=False, **kwargs):
1588 error_name = ErrorIf.KernelSmallerOne
1589 param_reqs = {"rank": None, "dtype": None, "shape": None}
1590 error_result = False
1591 error_reason = "At least one kernel dimension is smaller than zero"
1592
1593 if check:
1594 kernel = kwargs["kernel"]
1595 if min(kernel) < 1:
1596 error_result = True
1597
1598 info_dict = {
1599 "error_name": error_name,
1600 "error_result": error_result,
1601 "error_reason": error_reason,
1602 "param_reqs": param_reqs,
1603 }
1604 return info_dict
1605
1606 @staticmethod
1607 def evStrideSmallerOne(check=False, **kwargs):
1608 error_name = ErrorIf.StrideSmallerOne
1609 param_reqs = {"rank": None, "dtype": None, "shape": None}
1610 error_result = False
1611 error_reason = "At least one stride dimension is smaller than zero"
1612
1613 if check:
1614 stride = kwargs["stride"]
1615 if min(stride) < 1:
1616 error_result = True
1617
1618 info_dict = {
1619 "error_name": error_name,
1620 "error_result": error_result,
1621 "error_reason": error_reason,
1622 "param_reqs": param_reqs,
1623 }
1624 return info_dict
1625
1626 @staticmethod
1627 def evDilationSmallerOne(check=False, **kwargs):
1628 error_result = check and min(kwargs["dilation"]) < 1
1629 return {
1630 "error_name": ErrorIf.DilationSmallerOne,
1631 "error_reason": "At least one dilation is smaller than one",
1632 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1633 "error_result": error_result,
1634 }
1635
1636 @staticmethod
1637 def evScaleTrue(check=False, **kwargs):
1638 error_name = ErrorIf.ScaleTrue
1639 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1640 error_result = False
1641 error_reason = "Scale set to true but input type is INT48"
1642
1643 if check:
1644 input_dtype = kwargs["input_dtype"]
1645 scale32 = kwargs["scale32"]
1646 if scale32 and input_dtype == DType.INT48:
1647 error_result = True
1648
1649 info_dict = {
1650 "error_name": error_name,
1651 "error_result": error_result,
1652 "error_reason": error_reason,
1653 "param_reqs": param_reqs,
1654 }
1655 return info_dict
1656
1657 @staticmethod
1658 def evScaleNotTrue(check=False, **kwargs):
1659 error_name = ErrorIf.ScaleNotTrue
1660 param_reqs = {"rank": None, "dtype": None, "shape": None}
1661 error_result = False
1662 error_reason = "Scale set to false but double round set to true"
1663
1664 if check:
1665 scale32 = kwargs["scale32"]
1666 double_round = kwargs["double_round"]
1667 if not scale32 and double_round:
1668 error_result = True
1669
1670 info_dict = {
1671 "error_name": error_name,
1672 "error_result": error_result,
1673 "error_reason": error_reason,
1674 "param_reqs": param_reqs,
1675 }
1676 return info_dict
1677
1678 @staticmethod
1679 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1680 error_name = ErrorIf.TensorSizeInputOutputMismatch
1681 param_reqs = {"rank": None, "dtype": None, "shape": None}
1682 error_result = False
1683 error_reason = "Input tensor size does not match output tensor size"
1684
1685 if check:
1686 input_shape = kwargs["input_shape"]
1687 output_shape = kwargs["output_shape"]
1688 input_size = np.prod(input_shape)
1689 output_size = np.prod(output_shape)
1690 if input_size != output_size:
1691 error_result = True
1692
1693 info_dict = {
1694 "error_name": error_name,
1695 "error_result": error_result,
1696 "error_reason": error_reason,
1697 "param_reqs": param_reqs,
1698 }
1699 return info_dict
1700
1701 @staticmethod
1702 def evStartSmallerZero(check=False, **kwargs):
1703 error_name = ErrorIf.StartSmallerZero
1704 param_reqs = {"rank": None, "dtype": None, "shape": None}
1705 error_result = False
1706 error_reason = "Starting point smaller than zero"
1707
1708 if check:
1709 input_shape = kwargs["input_shape"]
1710 start = kwargs["start"]
1711 rank = len(input_shape)
1712 if len(start) == rank:
1713 for index in range(rank):
1714 if start[index] < 0:
1715 error_result = True
1716
1717 info_dict = {
1718 "error_name": error_name,
1719 "error_result": error_result,
1720 "error_reason": error_reason,
1721 "param_reqs": param_reqs,
1722 }
1723 return info_dict
1724
1725 @staticmethod
1726 def evSizeSmallerEqualZero(check=False, **kwargs):
1727 error_name = ErrorIf.SizeSmallerEqualZero
1728 param_reqs = {"rank": None, "dtype": None, "shape": None}
1729 error_result = False
1730 error_reason = "Size smaller than or equal to zero"
1731
1732 if check:
1733 input_shape = kwargs["input_shape"]
1734 size = kwargs["size"]
1735 rank = len(input_shape)
1736 if len(size) == rank:
1737 for index in range(rank):
1738 if size[index] <= 0:
1739 error_result = True
1740
1741 info_dict = {
1742 "error_name": error_name,
1743 "error_result": error_result,
1744 "error_reason": error_reason,
1745 "param_reqs": param_reqs,
1746 }
1747 return info_dict
1748
1749 @staticmethod
1750 def evStartSizeOutsideBounds(check=False, **kwargs):
1751 error_name = ErrorIf.StartSizeOutsideBounds
1752 param_reqs = {"rank": None, "dtype": None, "shape": None}
1753 error_result = False
1754 error_reason = "starting point plus size larger than input dimension"
1755
1756 if check:
1757 input_shape = kwargs["input_shape"]
1758 start = kwargs["start"]
1759 size = kwargs["size"]
1760 rank = len(input_shape)
1761 if len(start) == rank and len(size) == rank:
1762 for index in range(rank):
1763 if start[index] + size[index] > input_shape[index]:
1764 error_result = True
1765
1766 info_dict = {
1767 "error_name": error_name,
1768 "error_result": error_result,
1769 "error_reason": error_reason,
1770 "param_reqs": param_reqs,
1771 }
1772 return info_dict
1773
1774 @staticmethod
1775 def evSizeOutputShapeMismatch(check=False, **kwargs):
1776 error_name = ErrorIf.SizeOutputShapeMismatch
1777 param_reqs = {"rank": None, "dtype": None, "shape": None}
1778 error_result = False
1779 error_reason = "Size does not match output dimension"
1780
1781 if check:
1782 input_shape = kwargs["input_shape"]
1783 output_shape = kwargs["output_shape"]
1784 size = kwargs["size"]
1785 rank = len(input_shape)
1786 if len(size) == rank:
1787 for index in range(rank):
1788 if size[index] != output_shape[index]:
1789 error_result = True
1790
1791 info_dict = {
1792 "error_name": error_name,
1793 "error_result": error_result,
1794 "error_reason": error_reason,
1795 "param_reqs": param_reqs,
1796 }
1797 return info_dict
1798
1799 @staticmethod
1800 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1801 error_name = ErrorIf.InputSizeStartLengthMismatch
1802 param_reqs = {"rank": None, "dtype": None, "shape": None}
1803 error_result = False
1804 error_reason = "rank of input not equal to length of start or size"
1805
1806 if check:
1807 input_shape = kwargs["input_shape"]
1808 start = kwargs["start"]
1809 size = kwargs["size"]
1810 rank = len(input_shape)
1811 if rank != len(start) or rank != len(size):
1812 error_result = True
1813
1814 info_dict = {
1815 "error_name": error_name,
1816 "error_result": error_result,
1817 "error_reason": error_reason,
1818 "param_reqs": param_reqs,
1819 }
1820 return info_dict
1821
1822 @staticmethod
1823 def evIndexOutsideBounds(check=False, **kwargs):
1824 error_name = ErrorIf.IndexOutsideBounds
1825 param_reqs = {"rank": None, "dtype": None, "shape": None}
1826 error_result = False
1827 error_reason = "Index outside of allowed bounds"
1828
1829 if check:
1830 input_shape = kwargs["input_shape"]
1831 perms = kwargs["perms"]
1832 rank = len(input_shape)
1833
1834 for index in perms:
1835 if index < 0 or index > rank:
1836 error_result = True
1837
1838 info_dict = {
1839 "error_name": error_name,
1840 "error_result": error_result,
1841 "error_reason": error_reason,
1842 "param_reqs": param_reqs,
1843 }
1844 return info_dict
1845
1846 @staticmethod
1847 def evIndexUsedTwice(check=False, **kwargs):
1848 error_name = ErrorIf.IndexUsedTwice
1849 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1850 error_result = False
1851 error_reason = "Index used multiple times"
1852
1853 if check:
1854 perms = kwargs["perms"]
1855
1856 unique_indices = []
1857 for index in perms:
1858 if index in unique_indices:
1859 error_result = True
1860 else:
1861 unique_indices.append(index)
1862
1863 info_dict = {
1864 "error_name": error_name,
1865 "error_result": error_result,
1866 "error_reason": error_reason,
1867 "param_reqs": param_reqs,
1868 }
1869 return info_dict
1870
1871 @staticmethod
1872 def evMaxSmallerMin(check=False, **kwargs):
1873 error_name = ErrorIf.MaxSmallerMin
1874 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1875 error_result = False
1876 error_reason = "Max value smaller than min value"
1877
1878 if check:
1879 max_val = kwargs["max_val"]
1880 min_val = kwargs["min_val"]
1881 if max_val < min_val:
1882 error_result = True
1883
1884 info_dict = {
1885 "error_name": error_name,
1886 "error_result": error_result,
1887 "error_reason": error_reason,
1888 "param_reqs": param_reqs,
1889 }
1890 return info_dict
1891
1892 @staticmethod
1893 def evConcatInputRankMismatch(check=False, **kwargs):
1894 error_name = ErrorIf.ConcatInputRankMismatch
1895 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1896 error_result = False
1897 error_reason = "Input ranks are not identical"
1898
1899 if check:
1900 inputs = kwargs["inputs"]
1901 input_shape = kwargs["input_shape"]
1902 for input in inputs:
1903 if len(input.shape) != len(input_shape):
1904 error_result = True
1905
1906 info_dict = {
1907 "error_name": error_name,
1908 "error_result": error_result,
1909 "error_reason": error_reason,
1910 "param_reqs": param_reqs,
1911 }
1912 return info_dict
1913
1914 @staticmethod
1915 def evConcatInputDimMismatch(check=False, **kwargs):
1916 error_name = ErrorIf.ConcatInputDimMismatch
1917 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1918 error_result = False
1919 error_reason = "Input dimensions differ on too many axes"
1920
1921 if check:
1922 inputs = kwargs["inputs"]
1923 input_shape = kwargs["input_shape"]
1924 axis = kwargs["axis"]
1925
1926 # Ensure rank is valid before checking dims.
1927 valid_rank = True
1928 for input in inputs:
1929 if len(input.shape) != len(input_shape):
1930 valid_rank = False
1931
1932 if valid_rank:
1933 for input in inputs:
1934 for i, dim in enumerate(input.shape):
1935 if dim != input_shape[i] and axis != i:
1936 error_result = True
1937
1938 info_dict = {
1939 "error_name": error_name,
1940 "error_result": error_result,
1941 "error_reason": error_reason,
1942 "param_reqs": param_reqs,
1943 }
1944 return info_dict
1945
1946 @staticmethod
1947 def evConcatShapeSumMismatch(check=False, **kwargs):
1948 error_name = ErrorIf.ConcatShapeSumMismatch
1949 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1950 error_result = False
1951 error_reason = "Sum of dimensions on axis not equal to output dimension"
1952
1953 if check:
1954 inputs = kwargs["inputs"]
1955 input_shape = kwargs["input_shape"]
1956 output_shape = kwargs["output_shape"]
1957 axis = kwargs["axis"]
1958
1959 # Ensure rank is valid before checking dims.
1960 valid_params = True
1961 for input in inputs:
1962 if len(input.shape) != len(input_shape):
1963 valid_params = False
1964 if axis < 0 or axis > len(input_shape):
1965 valid_params = False
1966
1967 if valid_params:
1968 axis_dim_sum = 0
1969 for input in inputs:
1970 axis_dim_sum += input.shape[axis]
1971
1972 if axis_dim_sum != output_shape[axis]:
1973 error_result = True
1974
1975 info_dict = {
1976 "error_name": error_name,
1977 "error_result": error_result,
1978 "error_reason": error_reason,
1979 "param_reqs": param_reqs,
1980 }
1981 return info_dict
1982
1983 @staticmethod
1984 def evInputListThenGraphMismatch(check=False, **kwargs):
1985 error_name = ErrorIf.CondIfInputListThenGraphMismatch
1986 param_reqs = {"rank": None, "dtype": None, "shape": None}
1987 error_result = False
1988 error_reason = "Input list shape does not match then-graph shape"
1989
1990 if check:
1991 a = kwargs["a"]
1992 b = kwargs["b"]
1993 basicBlocks = kwargs["basicBlocks"]
1994 then_block = basicBlocks[1]
1995 then_inputs = then_block.inputs
1996 then_tens = then_block.tensors
1997 if (a.shape != then_tens[then_inputs[0]].shape) or (
1998 b.shape != then_tens[then_inputs[1]].shape
1999 ):
2000 error_result = True
2001
2002 info_dict = {
2003 "error_name": error_name,
2004 "error_result": error_result,
2005 "error_reason": error_reason,
2006 "param_reqs": param_reqs,
2007 }
2008 return info_dict
2009
2010 @staticmethod
2011 def evInputListElseGraphMismatch(check=False, **kwargs):
2012 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2013 param_reqs = {"rank": None, "dtype": None, "shape": None}
2014 error_result = False
2015 error_reason = "Input list shape does not match else-graph shape"
2016
2017 if check:
2018 a = kwargs["a"]
2019 b = kwargs["b"]
2020 basicBlocks = kwargs["basicBlocks"]
2021 else_block = basicBlocks[2]
2022 else_inputs = else_block.inputs
2023 else_tens = else_block.tensors
2024 if (a.shape != else_tens[else_inputs[0]].shape) or (
2025 b.shape != else_tens[else_inputs[1]].shape
2026 ):
2027 error_result = True
2028
2029 info_dict = {
2030 "error_name": error_name,
2031 "error_result": error_result,
2032 "error_reason": error_reason,
2033 "param_reqs": param_reqs,
2034 }
2035 return info_dict
2036
2037 @staticmethod
2038 def evOutputListThenGraphMismatch(check=False, **kwargs):
2039 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2040 param_reqs = {"rank": None, "dtype": None, "shape": None}
2041 error_result = False
2042 error_reason = "Output list shape does not match then-graph shape"
2043
2044 if check:
2045 basicBlocks = kwargs["basicBlocks"]
2046 cond_block = basicBlocks[0]
2047 cond_outputs = cond_block.outputs
2048 cond_tens = cond_block.tensors
2049 then_block = basicBlocks[1]
2050 then_outputs = then_block.outputs
2051 then_tens = then_block.tensors
2052 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2053 error_result = True
2054
2055 info_dict = {
2056 "error_name": error_name,
2057 "error_result": error_result,
2058 "error_reason": error_reason,
2059 "param_reqs": param_reqs,
2060 }
2061 return info_dict
2062
2063 @staticmethod
2064 def evOutputListElseGraphMismatch(check=False, **kwargs):
2065 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2066 param_reqs = {"rank": None, "dtype": None, "shape": None}
2067 error_result = False
2068 error_reason = "Output list shape does not match else-graph shape"
2069
2070 if check:
2071 basicBlocks = kwargs["basicBlocks"]
2072 cond_block = basicBlocks[0]
2073 cond_outputs = cond_block.outputs
2074 cond_tens = cond_block.tensors
2075 else_block = basicBlocks[2]
2076 else_outputs = else_block.outputs
2077 else_tens = else_block.tensors
2078 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2079 error_result = True
2080
2081 info_dict = {
2082 "error_name": error_name,
2083 "error_result": error_result,
2084 "error_reason": error_reason,
2085 "param_reqs": param_reqs,
2086 }
2087 return info_dict
2088
2089 @staticmethod
2090 def evInputListOutputListMismatch(check=False, **kwargs):
2091 error_name = ErrorIf.InputListOutputListMismatch
2092 param_reqs = {"rank": None, "dtype": None, "shape": None}
2093 error_result = False
2094 error_reason = "Input list does not match output list"
2095
2096 if check:
2097 basicBlocks = kwargs["basicBlocks"]
2098 while_block = basicBlocks[0]
2099 while_inputs = while_block.inputs
2100 while_outputs = while_block.outputs
2101 while_tens = while_block.tensors
2102 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2103 error_result = True
2104
2105 info_dict = {
2106 "error_name": error_name,
2107 "error_result": error_result,
2108 "error_reason": error_reason,
2109 "param_reqs": param_reqs,
2110 }
2111 return info_dict
2112
2113 @staticmethod
2114 def evInputListCondGraphMismatch(check=False, **kwargs):
2115 error_name = ErrorIf.InputListCondGraphMismatch
2116 param_reqs = {"rank": None, "dtype": None, "shape": None}
2117 error_result = False
2118 error_reason = "Input list does not match cond graph"
2119
2120 if check:
2121 basicBlocks = kwargs["basicBlocks"]
2122 while_block = basicBlocks[0]
2123 while_inputs = while_block.inputs
2124 while_tens = while_block.tensors
2125 cond_block = basicBlocks[1]
2126 cond_inputs = cond_block.inputs
2127 cond_tens = cond_block.tensors
2128 if (
2129 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2130 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2131 error_result = True
2132
2133 info_dict = {
2134 "error_name": error_name,
2135 "error_result": error_result,
2136 "error_reason": error_reason,
2137 "param_reqs": param_reqs,
2138 }
2139 return info_dict
2140
2141 @staticmethod
2142 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2143 error_name = ErrorIf.InputListBodyGraphInputMismatch
2144 param_reqs = {"rank": None, "dtype": None, "shape": None}
2145 error_result = False
2146 error_reason = "Input list does not match body graph input"
2147
2148 if check:
2149 basicBlocks = kwargs["basicBlocks"]
2150 while_block = basicBlocks[0]
2151 while_inputs = while_block.inputs
2152 while_tens = while_block.tensors
2153 body_block = basicBlocks[2]
2154 body_outputs = body_block.inputs
2155 body_tens = body_block.tensors
2156 if (
2157 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2158 ) or (
2159 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2160 ):
2161 error_result = True
2162
2163 info_dict = {
2164 "error_name": error_name,
2165 "error_result": error_result,
2166 "error_reason": error_reason,
2167 "param_reqs": param_reqs,
2168 }
2169 return info_dict
2170
2171 @staticmethod
2172 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2173 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2174 param_reqs = {"rank": None, "dtype": None, "shape": None}
2175 error_result = False
2176 error_reason = "Input list does not match body graph output"
2177
2178 if check:
2179 basicBlocks = kwargs["basicBlocks"]
2180 while_block = basicBlocks[0]
2181 while_inputs = while_block.inputs
2182 while_tens = while_block.tensors
2183 body_block = basicBlocks[2]
2184 body_outputs = body_block.outputs
2185 body_tens = body_block.tensors
2186 if (
2187 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2188 ) or (
2189 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2190 ):
2191 error_result = True
2192 info_dict = {
2193 "error_name": error_name,
2194 "error_result": error_result,
2195 "error_reason": error_reason,
2196 "param_reqs": param_reqs,
2197 }
2198 return info_dict
2199
2200 @staticmethod
2201 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2202 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2203 param_reqs = {"rank": None, "dtype": None, "shape": None}
2204 error_result = False
2205 error_reason = "Cond graph output is not a match list of booleans"
2206
2207 if check:
2208 basicBlocks = kwargs["basicBlocks"]
2209 cond_block = basicBlocks[1]
2210 cond_outputs = cond_block.outputs
2211 cond_tens = cond_block.tensors
2212 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2213 error_result = True
2214
2215 info_dict = {
2216 "error_name": error_name,
2217 "error_result": error_result,
2218 "error_reason": error_reason,
2219 "param_reqs": param_reqs,
2220 }
2221 return info_dict
2222
2223
class TosaInvalidValidator:
    """Predicates that decide whether a generated parameter set is invalid.

    Each static method returns True when the supplied operator arguments
    describe an invalid/unsupported combination that the test generator
    should not emit as a positive test case.

    NOTE(review): all methods decode positional `args` tuples whose layout
    is defined by the argument-generation code elsewhere in this module —
    the index comments below reflect that assumed layout; confirm against
    the argument builders if the layout changes.
    """

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True if the RESIZE mode/dtype combination is invalid.

        kwargs:
            input_dtype: DType of the input tensor.
            args: positional op args; args[0] is the resize mode,
                args[5] the output dtype (assumed layout — see class note).
        """
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[5]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype.
            # Only INT8->INT32, INT16->INT48 and FLOAT->FLOAT are valid
            # BILINEAR combinations; anything else is invalid.
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype.
            # NEAREST must preserve the dtype, and only these input types
            # are accepted.
            return (input_dtype != output_dtype) or (
                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True if the op's computed output height/width is invalid.

        kwargs:
            opName: operator name, used to select the geometry formula.
            shapeList: input shapes; shapeList[0] is the input tensor shape
                (NHWC assumed), shapeList[1] (where used) the filter shape.
            args: positional op args; args[0] is strides, args[1] padding,
                args[2] varies per op (assumed layout — see class note).
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]
        strides = args[0]
        padding = args[1]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            # Standard pooling output size:
            # (in + pad_before + pad_after + stride - kernel) // stride
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[2]
            filter_shape = inputShapes[1]
            # Filter layout assumed OHWI: spatial dims are the middle two.
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # The requested output shape is valid if it matches the computed
            # size under any of the three conventional padding schemes.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters carry spatial dims first (HWCM); regular
            # conv filters are OHWI/ODHWI with spatial dims in the middle.
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            # Dilated-convolution output size per spatial dimension:
            # (in - effective_kernel + pad_before + pad_after) // stride + 1
            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        # Programming error: an op was routed here that this validator does
        # not understand. (Note: assert is stripped under `python -O`.)
        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True if the requested output height/width is non-positive.

        kwargs:
            args: positional op args; args[2] is the requested output shape
                (assumed layout — see class note).
        """
        args = kwargs["args"]
        output_shape = args[2]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Zero or negative output height/width is invalid.
            return True
        return False