blob: a7668031f303b92f9da01ae985e4b732277cd930 [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005from generator.tosa_utils import product
6from generator.tosa_utils import usableDTypes
7from generator.tosa_utils import valueToName
8from tosa.DType import DType
9from tosa.Op import Op
10from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011
Matthew Haddone86fd342021-09-07 16:12:21 +010012
class ErrorIf(object):
    """String identifiers for the TOSA ERROR_IF conditions exercised by tests.

    Each attribute names one ERROR_IF condition; the value is the same
    string and is used both to select argument-corruption helpers
    (TosaErrorIfArgGen) and validator functions (TosaErrorValidator).
    """

    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010076
77
class TosaErrorIfArgGen:
    """Helpers that corrupt otherwise-valid operator arguments.

    Each ei* helper takes an ERROR_IF condition name and returns arguments
    altered so that the corresponding ERROR_IF check will fire.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Corrupt RESIZE scale/offset/border/output type for error_name.

        scale, offset and border are modified in place and also returned,
        together with the (possibly replaced) outputDType.
        """
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # scale numerators live at even indices; maximum is 1 << 11
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # scale denominators live at odd indices; must be < 16 * numerator
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Pick any output type invalid for this mode/input type combination
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one component made invalid.

        Returns (None, None, None) when error_name is not a pooling
        argument error, or the StrideSmallerOne precondition fails.
        """
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if output_dtype is invalid for a RESCALE of input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for ERROR_IF checks.

        Randomly grows or shrinks the relevant list so its length no
        longer matches the operator's expectation.
        """
        # Consistency fix: compare against the ErrorIf constants rather than
        # duplicating the literal strings (the values are identical).
        if error_name == ErrorIf.WrongInputList:
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == ErrorIf.WrongOutputList:
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        # Shrink every dimension together until the element count fits
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # BUG FIX: was missing @staticmethod unlike every other helper, so an
    # instance call would have bound testGen to self.
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Corrupt SLICE start/size parameters for error_name."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid output types for a CAST of input_dtype."""
        if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # BUG FIX: was `assert True, ...` which can never fire and left
            # outputDType unbound, raising NameError on the return below.
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
309
310
311class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        overall_result = True
        for val_fcn in validator_fcns:
            # Run each validator in "check" mode against the generated arguments
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                # NOTE(review): 2 appears to be the reference model's
                # "ERROR_IF raised" return code - confirm against serializer API
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                # Dump the validator arguments to help diagnose the mismatch
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f" {k} = {v}")

        return overall_result
356
357 @staticmethod
358 def evWrongInputType(check=False, **kwargs):
359 error_result = False
360
361 # Find the unsupported input data types
362 op = kwargs["op"]
363 input_dtypes = op["types"]
364 allowed_input_dtypes = {
365 t[0] if isinstance(t, list) else t for t in input_dtypes
366 }
367 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
368
369 if op["op"] == Op.CLAMP:
370 wrong_input_dtypes.remove(DType.INT48)
371
372 if check:
373 input_dtype = kwargs["input_dtype"]
374 if input_dtype not in allowed_input_dtypes:
375 error_result = True
376
377 info_dict = {
378 "error_name": ErrorIf.WrongInputType,
379 "error_result": error_result,
380 "error_reason": "Input data type not supported for this operator",
381 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
382 }
383 return info_dict
384
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Check whether output_dtype is invalid for the operator and input_dtype.

        Returns the standard validator info dict; error_result is True when
        the output data type does not match the operator's allowed mapping.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # Valid pairs depend on the resize mode for the integer types
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE's mapping is shared with the argument generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FLOAT)
                    )
                    or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL always accumulates into INT32; float output
                # must match the float input type
                if (
                    input_dtype not in (DType.FP16, DType.FLOAT)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison operators always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input type lists the output types it may cast to;
                # anything else is an error
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FLOAT,
                            DType.FP16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FLOAT,
                            DType.FP16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FLOAT,
                            DType.FP16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.FLOAT
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                # Relies on `and` binding tighter than `or`: each pair is
                # (input type matches) and (output type wrong)
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FLOAT)
                    or input_dtype == DType.FLOAT
                    and output_dtype != DType.FLOAT
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default: elementwise-style ops keep the input type
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
549
550 @staticmethod
551 def evWrongRank(check=False, **kwargs):
552 all_ranks = (1, 2, 3, 4, 5)
553
554 # Make a list of incorrect ranks
555 assert "op" in kwargs
556 op = kwargs["op"]
557 rmin, rmax = op["rank"]
558 rank_range = range(rmin, rmax + 1)
559 incorrect_ranks = list(set(all_ranks) - set(rank_range))
560 # Remove small incorrect ranks to avoid index errors
561 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
562 # Set minimum incorrect rank to 3 to avoid index error
563 if op["op"] in [Op.RESIZE]:
564 incorrect_ranks = [3, 5]
565 elif op["op"] in [Op.TRANSPOSE]:
566 incorrect_ranks = [7, 8]
567 elif op["op"] in [Op.CONV3D]:
568 incorrect_ranks = [6, 7]
569
570 error_name = ErrorIf.WrongRank
571 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
572 error_result = False
573 error_reason = "Rank not supported for this operator"
574
575 if check:
576 input_shape = kwargs["input_shape"]
577
578 if (
579 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
580 and len(input_shape) != 4
581 ):
582 error_result = True
583 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
584 error_result = True
585 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
586 error_result = True
587 else:
588 if len(input_shape) not in rank_range:
589 error_result = True
590
591 info_dict = {
592 "error_name": error_name,
593 "error_result": error_result,
594 "error_reason": error_reason,
595 "param_reqs": param_reqs,
596 }
597 return info_dict
598
599 @staticmethod
600 def evWrongInputList(check=False, **kwargs):
601 error_name = ErrorIf.WrongInputList
602 param_reqs = {"rank": None, "dtype": None, "shape": None}
603 error_result = False
604 error_reason = "Op input list does not match expected input"
605
606 if check:
607 op = kwargs["op"]
608 input_list = kwargs["input_list"]
609 num_operands = kwargs["num_operands"]
610 if op["op"] in [Op.SCATTER, Op.GATHER]:
611 # SCATTER/GATHER add an indices input tensor in their build functions
612 num_operands += 1
613 if len(input_list) != num_operands:
614 error_result = True
615
616 info_dict = {
617 "error_name": error_name,
618 "error_result": error_result,
619 "error_reason": error_reason,
620 "param_reqs": param_reqs,
621 }
622 return info_dict
623
624 @staticmethod
625 def evWrongOutputList(check=False, **kwargs):
626 error_name = ErrorIf.WrongOutputList
627 param_reqs = {"rank": None, "dtype": None, "shape": None}
628 error_result = False
629 error_reason = "Op output list does not match expected output"
630
631 if check:
632 output_list = kwargs["output_list"]
633 # Note this will be incorrect if an operator returns more than one output
634 if len(output_list) != 1:
635 error_result = True
636
637 info_dict = {
638 "error_name": error_name,
639 "error_result": error_result,
640 "error_reason": error_reason,
641 "param_reqs": param_reqs,
642 }
643 return info_dict
644
645 @staticmethod
646 def evMaxDimExceeded(check=False, **kwargs):
647 error_name = ErrorIf.MaxDimExceeded
648 param_reqs = {
649 "rank": [4, 4],
650 "dtype": [DType.INT8],
651 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
652 }
653 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100654 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100655
656 if check:
657 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100658 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100659 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100660 (input_shape[1] >= MAX_RESIZE_DIMENSION)
661 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
662 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
663 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100664 ):
665 error_result = True
666
667 info_dict = {
668 "error_name": error_name,
669 "error_result": error_result,
670 "error_reason": error_reason,
671 "param_reqs": param_reqs,
672 }
673 return info_dict
674
675 @staticmethod
676 def evBatchMismatch(check=False, **kwargs):
677 error_name = ErrorIf.BatchMismatch
678 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
679 error_result = False
680 error_reason = "Input batch size not equal to output batch size"
681
682 assert "op" in kwargs
683 op = kwargs["op"]
684 rmin, rmax = op["rank"]
685 rank_range = range(rmin, rmax + 1)
686
687 if check:
688 input_shape = kwargs["input_shape"]
689 output_shape = kwargs[
690 "result_tensor"
691 ].shape # Note this is just (N, OH, OW, C)
692
693 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
694 error_result = True
695
696 info_dict = {
697 "error_name": error_name,
698 "error_result": error_result,
699 "error_reason": error_reason,
700 "param_reqs": param_reqs,
701 }
702 return info_dict
703
704 @staticmethod
705 def evChannelMismatch(check=False, **kwargs):
706 error_name = ErrorIf.ChannelMismatch
707 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
708 error_result = False
709 error_reason = "Input channel size not equal to output channel size"
710
711 assert "op" in kwargs
712 op = kwargs["op"]
713 rmin, rmax = op["rank"]
714 rank_range = range(rmin, rmax + 1)
715
716 if check:
717 input_shape = kwargs["input_shape"]
718 output_shape = kwargs[
719 "result_tensor"
720 ].shape # Note this is just (N, OH, OW, C)
721 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
722 error_result = True
723
724 info_dict = {
725 "error_name": error_name,
726 "error_result": error_result,
727 "error_reason": error_reason,
728 "param_reqs": param_reqs,
729 }
730 return info_dict
731
732 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100733 def evScaleSmallerEqualZero(check=False, **kwargs):
734 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100735 param_reqs = {"rank": None, "dtype": None, "shape": None}
736 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100737 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100738
739 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100740 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100741
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100742 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100743 error_result = True
744
745 info_dict = {
746 "error_name": error_name,
747 "error_result": error_result,
748 "error_reason": error_reason,
749 "param_reqs": param_reqs,
750 }
751 return info_dict
752
753 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100754 def evScaleNLargerMax(check=False, **kwargs):
755 error_name = ErrorIf.ScaleNLargerMax
756 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100757 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100758 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100759
760 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100761 scale = kwargs["scale"]
762
763 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
764 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100765
766 info_dict = {
767 "error_name": error_name,
768 "error_result": error_result,
769 "error_reason": error_reason,
770 "param_reqs": param_reqs,
771 }
772 return info_dict
773
774 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100775 def evScaleDLargerMax(check=False, **kwargs):
776 error_name = ErrorIf.ScaleDLargerMax
777 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100778 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100779 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100780
781 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100782 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100783
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100784 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
785 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100786 ):
787 error_result = True
788
789 info_dict = {
790 "error_name": error_name,
791 "error_result": error_result,
792 "error_reason": error_reason,
793 "param_reqs": param_reqs,
794 }
795 return info_dict
796
797 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100798 def evOffsetSmallerMin(check=False, **kwargs):
799 error_name = ErrorIf.OffsetSmallerMin
800 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100801 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100802 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100803
804 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100805 scale = kwargs["scale"]
806 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100807
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100808 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100809 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100810 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100811 error_result = True
812
813 info_dict = {
814 "error_name": error_name,
815 "error_result": error_result,
816 "error_reason": error_reason,
817 "param_reqs": param_reqs,
818 }
819 return info_dict
820
821 @staticmethod
822 def evOffsetLargerEqualMax(check=False, **kwargs):
823 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100824 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100825 error_result = False
826 error_reason = "Offset value larger than or equal to maximum value"
827
828 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100829 scale = kwargs["scale"]
830 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100831
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100832 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
833 error_result = True
834 elif (
835 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
836 ):
837 error_result = True
838
839 info_dict = {
840 "error_name": error_name,
841 "error_result": error_result,
842 "error_reason": error_reason,
843 "param_reqs": param_reqs,
844 }
845 return info_dict
846
847 @staticmethod
848 def evBorderSmallerMin(check=False, **kwargs):
849 error_name = ErrorIf.BorderSmallerMin
850 param_reqs = {"rank": None, "dtype": None, "shape": None}
851 error_result = False
852 error_reason = "Border value smaller than minimum value"
853
854 if check:
855 scale = kwargs["scale"]
856 border = kwargs["border"]
857
858 if (
859 scale[0] > 0
860 and scale[0] <= (1 << 11)
861 and (border[0] < (-16 * scale[0]))
862 ):
863 error_result = True
864 elif (
865 scale[2] > 0
866 and scale[2] <= (1 << 11)
867 and (border[1] < (-16 * scale[2]))
868 ):
869 error_result = True
870
871 info_dict = {
872 "error_name": error_name,
873 "error_result": error_result,
874 "error_reason": error_reason,
875 "param_reqs": param_reqs,
876 }
877 return info_dict
878
879 @staticmethod
880 def evBorderLargerEqualMax(check=False, **kwargs):
881 error_name = ErrorIf.BorderLargerEqualMax
882 param_reqs = {"rank": None, "dtype": None, "shape": None}
883 error_result = False
884 error_reason = "Border value larger than or equal to maximum value"
885
886 if check:
887 scale = kwargs["scale"]
888 border = kwargs["border"]
889
890 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
891 error_result = True
892 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
893 error_result = True
894
895 info_dict = {
896 "error_name": error_name,
897 "error_result": error_result,
898 "error_reason": error_reason,
899 "param_reqs": param_reqs,
900 }
901 return info_dict
902
903 @staticmethod
904 def checkResizeParams(scale, offset, border):
905 return (
906 min(scale) > 0
907 and max(scale[0], scale[2]) <= (1 << 11)
908 and scale[1] < 16 * scale[0]
909 and scale[3] < 16 * scale[2]
910 and offset[0] >= -scale[0]
911 and offset[1] >= -scale[2]
912 and offset[0] < 16 * scale[0]
913 and offset[1] < 16 * scale[2]
914 and border[0] >= -16 * scale[0]
915 and border[1] >= -16 * scale[2]
916 and border[0] < scale[0]
917 and border[1] < scale[2]
918 )
919
920 @staticmethod
921 def evResizeOutputShapeMismatch(check=False, **kwargs):
922 error_name = ErrorIf.ResizeOutputShapeMismatch
923 param_reqs = {"rank": None, "dtype": None, "shape": None}
924 error_result = False
925 error_reason = (
926 "Mismatch between output shape provided and expected output shape"
927 )
928
929 if check:
930 input_shape = kwargs["input_shape"]
931 output_shape = kwargs["output_shape"]
932 scale = kwargs["scale"]
933 offset = kwargs["offset"]
934 border = kwargs["border"]
935
936 # Ensure parameters are valid
937 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
938
939 if (
940 params_valid
941 and max(output_shape) < MAX_RESIZE_DIMENSION
942 and max(input_shape) < MAX_RESIZE_DIMENSION
943 ):
944 output_y = (
945 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
946 ) // scale[1] + 1
947 output_x = (
948 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
949 ) // scale[3] + 1
950
951 if [output_y, output_x] != output_shape[1:-1]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100952 error_result = True
953
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100954 info_dict = {
955 "error_name": error_name,
956 "error_result": error_result,
957 "error_reason": error_reason,
958 "param_reqs": param_reqs,
959 }
960 return info_dict
961
962 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100963 def evResizeOutputShapeNonInteger(check=False, **kwargs):
964 error_name = ErrorIf.ResizeOutputShapeNonInteger
965 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100966 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100967 error_reason = "Parameters do not yield exact integer output dimensions"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100968
969 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100970 input_shape = kwargs["input_shape"]
971 scale = kwargs["scale"]
972 offset = kwargs["offset"]
973 border = kwargs["border"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100974
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100975 # Ensure parameters are valid
976 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100977
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100978 if params_valid:
979 remainder_y = (
980 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
981 ) % scale[1]
982 remainder_x = (
983 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
984 ) % scale[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100985
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100986 if max(remainder_y, remainder_x) > 0:
987 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100988
989 info_dict = {
990 "error_name": error_name,
991 "error_result": error_result,
992 "error_reason": error_reason,
993 "param_reqs": param_reqs,
994 }
995 return info_dict
996
997 @staticmethod
998 def evRankMismatch(check=False, **kwargs):
999 error_name = ErrorIf.RankMismatch
1000 param_reqs = {"rank": None, "dtype": None, "shape": None}
1001 error_result = False
1002 error_reason = "Input Rank does not match output rank"
1003
1004 if check:
1005 input1_shape = kwargs["input1"].shape
1006 input2_shape = kwargs["input2"].shape
1007 # In case of SELECT op
1008 input3_shape = (
1009 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1010 )
1011 output_shape = kwargs["result_tensor"].shape
1012 if (
1013 (len(input1_shape) != len(output_shape))
1014 or (len(input2_shape) != len(output_shape))
1015 or (len(input3_shape) != len(output_shape))
1016 ):
1017 error_result = True
1018
1019 info_dict = {
1020 "error_name": error_name,
1021 "error_result": error_result,
1022 "error_reason": error_reason,
1023 "param_reqs": param_reqs,
1024 }
1025 return info_dict
1026
1027 @staticmethod
1028 def evDimensionMismatch(check=False, **kwargs):
1029 error_name = ErrorIf.DimensionMismatch
1030 param_reqs = {"rank": None, "dtype": None, "shape": None}
1031 error_result = False
1032 error_reason = "Input Dimensions do not match output"
1033
1034 if check:
1035 input1_shape = kwargs["input1"].shape
1036 input2_shape = kwargs["input2"].shape
1037 # In case of SELECT op
1038 input3_shape = (
1039 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1040 )
1041 output_shape = kwargs["result_tensor"].shape
1042 for i in range(
1043 min(len(input1_shape), len(input2_shape), len(input3_shape))
1044 ):
1045 if (
1046 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
1047 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
1048 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
1049 ):
1050 error_result = True
1051
1052 info_dict = {
1053 "error_name": error_name,
1054 "error_result": error_result,
1055 "error_reason": error_reason,
1056 "param_reqs": param_reqs,
1057 }
1058 return info_dict
1059
1060 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001061 def _getZeroPoint(qinfo, index):
1062 """Return zero point value from quantization info.
1063
1064 Generally input_zp is index 0, output_zp is index 1
1065 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001066 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001067
1068 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001069 def evInputZeroPointNotZero(check=False, **kwargs):
1070 op = kwargs["op"]
1071 error_result = False
1072
1073 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001074 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001075
1076 # This does not apply to quantizable types
1077 inputDtypes = [
1078 dtype
1079 for dtype in op["types"]
1080 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1081 or (not isinstance(dtype, list) and dtype not in qTypes)
1082 ]
1083
1084 if check:
1085 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001086 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001087 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001088 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001089 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001090 (kwargs["input_dtype"], input_zero_point),
1091 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092 ):
1093 if dtype not in qTypes and zp != 0:
1094 error_result = True
1095 break
1096 else:
1097 error_result = input_dtype not in qTypes and input_zero_point != 0
1098
1099 info_dict = {
1100 "error_name": ErrorIf.InputZeroPointNotZero,
1101 "error_result": error_result,
1102 "error_reason": "Input DType not INT8 and zero point not 0",
1103 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1104 }
1105 return info_dict
1106
1107 @staticmethod
1108 def evWeightZeroPointNotZero(check=False, **kwargs):
1109 op = kwargs["op"]
1110
1111 # exclude inputs with INT8 weights
1112 inputDtypes = [
1113 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1114 ]
1115
1116 error_name = ErrorIf.WeightZeroPointNotZero
1117 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1118 error_result = False
1119 error_reason = "Weight DType not INT8 and zero point not 0"
1120
1121 if check:
1122 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001123 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001124 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1125 error_result = True
1126
1127 info_dict = {
1128 "error_name": error_name,
1129 "error_result": error_result,
1130 "error_reason": error_reason,
1131 "param_reqs": param_reqs,
1132 }
1133 return info_dict
1134
1135 @staticmethod
1136 def evOutputZeroPointNotZero(check=False, **kwargs):
1137 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001138 inputDtypes = [
1139 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1140 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001141
1142 error_name = ErrorIf.OutputZeroPointNotZero
1143 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1144 error_result = False
1145 error_reason = "Output DType not INT8 and zero point not 0"
1146
1147 if check:
1148 input_dtype = kwargs["input_dtype"]
1149 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001150 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001151 if op["op"] == Op.AVG_POOL2D:
1152 if input_dtype != DType.INT8 and output_zero_point != 0:
1153 error_result = True
1154 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001155 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1156 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001157 ):
1158 error_result = True
1159
1160 info_dict = {
1161 "error_name": error_name,
1162 "error_result": error_result,
1163 "error_reason": error_reason,
1164 "param_reqs": param_reqs,
1165 }
1166 return info_dict
1167
1168 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001169 def evU16InputZeroPointNotValid(check=False, **kwargs):
1170 error_name = ErrorIf.U16InputZeroPointNotValid
1171 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1172 error_result = False
1173 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1174
1175 if check:
1176 input_dtype = kwargs["input_dtype"]
1177 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1178 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1179 0,
1180 32768,
1181 ]
1182
1183 info_dict = {
1184 "error_name": error_name,
1185 "error_result": error_result,
1186 "error_reason": error_reason,
1187 "param_reqs": param_reqs,
1188 }
1189 return info_dict
1190
1191 @staticmethod
1192 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1193 error_name = ErrorIf.U16OutputZeroPointNotValid
1194 param_reqs = {"rank": None, "dtype": None, "shape": None}
1195 error_result = False
1196 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1197
1198 if check:
1199 output_dtype = kwargs["output_dtype"]
1200 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1201
1202 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1203 0,
1204 32768,
1205 ]
1206
1207 info_dict = {
1208 "error_name": error_name,
1209 "error_result": error_result,
1210 "error_reason": error_reason,
1211 "param_reqs": param_reqs,
1212 }
1213 return info_dict
1214
1215 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001216 def evAxisSmallerZero(check=False, **kwargs):
1217 error_name = ErrorIf.AxisSmallerZero
1218 param_reqs = {"rank": None, "dtype": None, "shape": None}
1219 error_result = False
1220 error_reason = "Axis smaller than zero"
1221
1222 if check:
1223 axis = kwargs["axis"]
1224 if axis < 0:
1225 error_result = True
1226
1227 info_dict = {
1228 "error_name": error_name,
1229 "error_result": error_result,
1230 "error_reason": error_reason,
1231 "param_reqs": param_reqs,
1232 }
1233 return info_dict
1234
1235 @staticmethod
1236 def evAxisLargerRank(check=False, **kwargs):
1237 error_name = ErrorIf.AxisLargerRank
1238 param_reqs = {"rank": None, "dtype": None, "shape": None}
1239 error_result = False
1240 error_reason = "Axis larger than rank"
1241
1242 if check:
1243 axis = kwargs["axis"]
1244 shape = kwargs["input_shape"]
1245 if axis > len(shape):
1246 error_result = True
1247
1248 info_dict = {
1249 "error_name": error_name,
1250 "error_result": error_result,
1251 "error_reason": error_reason,
1252 "param_reqs": param_reqs,
1253 }
1254 return info_dict
1255
1256 @staticmethod
1257 def evShapeOfAxisNotOne(check=False, **kwargs):
1258 error_name = ErrorIf.ShapeOfAxisNotOne
1259 param_reqs = {"rank": None, "dtype": None, "shape": None}
1260 error_result = False
1261 error_reason = "shape[axis] is not equal to 1"
1262
1263 if check:
1264 axis = kwargs["axis"]
1265 shape = kwargs["output_shape"]
1266 if (0 <= axis < len(shape)) and shape[axis] != 1:
1267 error_result = True
1268
1269 info_dict = {
1270 "error_name": error_name,
1271 "error_result": error_result,
1272 "error_reason": error_reason,
1273 "param_reqs": param_reqs,
1274 }
1275 return info_dict
1276
1277 @staticmethod
1278 def evPadSmallerZero(check=False, **kwargs):
1279 error_name = ErrorIf.PadSmallerZero
1280 param_reqs = {"rank": None, "dtype": None, "shape": None}
1281 error_result = False
1282 error_reason = "At least one pad is smaller than zero"
1283
1284 if check:
1285 op = kwargs["op"]
1286 pad = kwargs["pad"]
1287 if op["op"] == Op.PAD:
1288 for padding in pad:
1289 if min(padding) < 0:
1290 error_result = True
1291 else:
1292 if min(pad) < 0:
1293 error_result = True
1294
1295 info_dict = {
1296 "error_name": error_name,
1297 "error_result": error_result,
1298 "error_reason": error_reason,
1299 "param_reqs": param_reqs,
1300 }
1301 return info_dict
1302
1303 @staticmethod
1304 def evPadLargerEqualKernel(check=False, **kwargs):
1305 error_name = ErrorIf.PadLargerEqualKernel
1306 param_reqs = {"rank": None, "dtype": None, "shape": None}
1307 error_result = False
1308 error_reason = "At least one pad is larger than kernel dimension"
1309
1310 if check:
1311 pad = kwargs["pad"]
Eric Kunzec1a97832022-07-01 16:56:09 -07001312 op = kwargs["op"]
1313 if op["op"] == Op.TRANSPOSE_CONV2D:
1314 # transpose_conv2d
1315 kernel = kwargs["weight_shape"][1:-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001316 if (
Eric Kunzec1a97832022-07-01 16:56:09 -07001317 pad[0] <= -kernel[0]
1318 or pad[1] <= -kernel[0]
1319 or pad[2] <= -kernel[1]
1320 or pad[3] <= -kernel[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001321 ):
1322 error_result = True
Eric Kunzec1a97832022-07-01 16:56:09 -07001323 else:
1324 # pooling op
1325 kernel = kwargs["kernel"]
1326 if min(pad) > 0 and min(kernel) > 1:
1327 if (
1328 pad[0] >= kernel[0]
1329 or pad[1] >= kernel[0]
1330 or pad[2] >= kernel[1]
1331 or pad[3] >= kernel[1]
1332 ):
1333 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001334
1335 info_dict = {
1336 "error_name": error_name,
1337 "error_result": error_result,
1338 "error_reason": error_reason,
1339 "param_reqs": param_reqs,
1340 }
1341 return info_dict
1342
1343 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001344 def evPadOutputShapeMismatch(check=False, **kwargs):
1345 error_name = ErrorIf.PadOutputShapeMismatch
1346 param_reqs = {"rank": None, "dtype": None, "shape": None}
1347 error_result = False
1348 error_reason = "Pad output shape mismatch for requested padding"
1349
1350 if check:
1351 pad = kwargs["pad"]
1352 input_shape = kwargs["input_shape"]
1353 output_shape = kwargs["output_shape"]
1354 for dim, padding in enumerate(pad):
1355 expected_size = input_shape[dim] + padding[0] + padding[1]
1356 if expected_size != output_shape[dim]:
1357 error_result = True
1358
1359 info_dict = {
1360 "error_name": error_name,
1361 "error_result": error_result,
1362 "error_reason": error_reason,
1363 "param_reqs": param_reqs,
1364 }
1365 return info_dict
1366
1367 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001368 def checkPoolingParams(kernel, stride, pad):
1369 return (
1370 min(kernel) >= 1
1371 and min(stride) >= 1
1372 and min(pad) >= 0
1373 and not (
1374 pad[0] >= kernel[0]
1375 or pad[1] >= kernel[0]
1376 or pad[2] >= kernel[1]
1377 or pad[3] >= kernel[1]
1378 )
1379 )
1380
1381 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001382 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1383 error_name = ErrorIf.PoolingOutputShapeMismatch
1384 param_reqs = {"rank": None, "dtype": None, "shape": None}
1385 error_result = False
1386 error_reason = (
1387 "Mismatch between output shape provided and expected output shape"
1388 )
1389
1390 if check:
1391 pad = kwargs["pad"]
1392 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1393
1394 kernel = kwargs["kernel"]
1395 kernel_y, kernel_x = kernel[0], kernel[1]
1396
1397 input_shape = kwargs["input_shape"]
1398 IH, IW = input_shape[1], input_shape[2]
1399
1400 output_shape = kwargs["output_shape"]
1401 OH, OW = output_shape[1], output_shape[2]
1402
1403 stride = kwargs["stride"]
1404 stride_y, stride_x = stride[0], stride[1]
1405
1406 # calculate correct height, width dimensions
1407 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001408 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1409 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001410
1411 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001412 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001413
1414 if params_valid and (OH != y_correct or OW != x_correct):
1415 error_result = True
1416
1417 info_dict = {
1418 "error_name": error_name,
1419 "error_result": error_result,
1420 "error_reason": error_reason,
1421 "param_reqs": param_reqs,
1422 }
1423 return info_dict
1424
1425 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001426 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1427 error_name = ErrorIf.PoolingOutputShapeNonInteger
1428 param_reqs = {"rank": None, "dtype": None, "shape": None}
1429 error_result = False
1430 error_reason = "Parameters do not yield exact integer output dimensions"
1431
1432 if check:
1433 pad = kwargs["pad"]
1434 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1435
1436 kernel = kwargs["kernel"]
1437 kernel_y, kernel_x = kernel[0], kernel[1]
1438
1439 input_shape = kwargs["input_shape"]
1440 IH, IW = input_shape[1], input_shape[2]
1441
1442 stride = kwargs["stride"]
1443 stride_y, stride_x = stride[0], stride[1]
1444
1445 # calculate remainder of height, width dimensions
1446 if stride_x != 0 and stride_y != 0:
1447 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1448 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1449
1450 # ensure parameters are valid
1451 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1452 if params_valid and (y_remainder != 0 or x_remainder != 0):
1453 error_result = True
1454
1455 info_dict = {
1456 "error_name": error_name,
1457 "error_result": error_result,
1458 "error_reason": error_reason,
1459 "param_reqs": param_reqs,
1460 }
1461 return info_dict
1462
1463 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001464 def checkConvParams(op, weight_shape, stride, pad, dilation):
1465 if op == Op.TRANSPOSE_CONV2D:
1466 pad_ok = (
1467 pad[0] > -weight_shape[1]
1468 and pad[1] > -weight_shape[1]
1469 and pad[2] > -weight_shape[2]
1470 and pad[3] > -weight_shape[2]
1471 )
1472 else:
1473 pad_ok = min(pad) >= 0
1474
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001475 return (
1476 # Check kernel sizes
1477 min(weight_shape[1:-1]) >= 1
1478 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001479 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001480 and (dilation is None or min(dilation) >= 1)
1481 )
1482
1483 @staticmethod
1484 def evConvOutputShapeMismatch(check=False, **kwargs):
1485 error_name = ErrorIf.ConvOutputShapeMismatch
1486 param_reqs = {"rank": None, "dtype": None, "shape": None}
1487 error_result = False
1488 error_reason = (
1489 "Mismatch between output shape provided and expected output shape"
1490 )
1491
1492 if check:
1493 op = kwargs["op"]
1494 pad = kwargs["pad"]
1495 weight_shape = kwargs["weight_shape"]
1496 input_shape = kwargs["input_shape"]
1497 output_shape = kwargs["output_shape"]
1498 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1499 stride = kwargs["stride"]
1500
1501 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1502
1503 # calculate correct dimensions
1504 dims_correct = []
1505 if min(stride) > 0:
1506 for index in range(len(stride)):
1507 pad_offset = index * 2
1508 if op["op"] == Op.TRANSPOSE_CONV2D:
1509 dims_correct.append(
1510 (input_shape[index + 1] - 1) * stride[index]
Eric Kunzec1a97832022-07-01 16:56:09 -07001511 + pad[pad_offset]
1512 + pad[pad_offset + 1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001513 + weight_shape[index + kernel_offset]
1514 )
1515 else:
1516 dims_correct.append(
1517 (
1518 input_shape[index + 1]
1519 - 1
1520 + pad[pad_offset]
1521 + pad[pad_offset + 1]
1522 - (weight_shape[index + kernel_offset] - 1)
1523 * dilation[index]
1524 )
1525 // stride[index]
1526 + 1
1527 )
1528
1529 # ensure parameters are valid
1530 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001531 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001532 )
1533
1534 if params_valid and output_shape[1:-1] != dims_correct:
1535 error_result = True
1536
1537 info_dict = {
1538 "error_name": error_name,
1539 "error_result": error_result,
1540 "error_reason": error_reason,
1541 "param_reqs": param_reqs,
1542 }
1543 return info_dict
1544
1545 @staticmethod
1546 def evConvOutputShapeNonInteger(check=False, **kwargs):
1547 error_name = ErrorIf.ConvOutputShapeNonInteger
1548 param_reqs = {"rank": None, "dtype": None, "shape": None}
1549 error_result = False
1550 error_reason = "Parameters do not yield exact integer output dimensions"
1551
1552 if check:
1553 op = kwargs["op"]
1554 pad = kwargs["pad"]
1555 weight_shape = kwargs["weight_shape"]
1556 input_shape = kwargs["input_shape"]
1557 dilation = kwargs["dilation"]
1558 stride = kwargs["stride"]
1559
1560 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1561
1562 # calculate correct height, width dimensions
1563 remainders = []
1564 if min(stride) > 0:
1565 for index in range(len(stride)):
1566 pad_offset = index * 2
1567 remainders.append(
1568 (
1569 input_shape[index + 1]
1570 - 1
1571 + pad[pad_offset]
1572 + pad[pad_offset + 1]
1573 - (weight_shape[index + kernel_offset] - 1)
1574 * dilation[index]
1575 )
1576 % stride[index]
1577 )
1578
1579 # ensure parameters are valid
1580 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001581 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001582 )
1583 if params_valid and max(remainders) > 0:
1584 error_result = True
1585
1586 info_dict = {
1587 "error_name": error_name,
1588 "error_result": error_result,
1589 "error_reason": error_reason,
1590 "param_reqs": param_reqs,
1591 }
1592 return info_dict
1593
1594 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001595 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1596 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1597 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1598 error_result = False
1599 error_reason = (
1600 "Mismatch between output shape provided and expected output shape"
1601 )
1602
1603 if check:
1604 output_shape = kwargs["output_shape"]
1605 input_shape = kwargs["input_shape"]
1606 axis = kwargs["axis"]
1607
1608 dimension_match = True
1609 axis_shift = 0
1610
1611 # Check that rank is correct before trying to check dimensions
1612 if (len(input_shape) - 1) == len(output_shape):
1613 for i in range(len(input_shape)):
1614 if i == axis:
1615 axis_shift = 1
1616 continue
1617 if input_shape[i] != output_shape[i - axis_shift]:
1618 dimension_match = False
1619
1620 if not dimension_match:
1621 error_result = True
1622
1623 info_dict = {
1624 "error_name": error_name,
1625 "error_result": error_result,
1626 "error_reason": error_reason,
1627 "param_reqs": param_reqs,
1628 }
1629 return info_dict
1630
1631 @staticmethod
1632 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1633 error_name = ErrorIf.ArgmaxOutputRankMismatch
1634 param_reqs = {"rank": None, "dtype": None, "shape": None}
1635 error_result = False
1636 error_reason = (
1637 "Mismatch between output shape provided and expected output shape"
1638 )
1639
1640 if check:
1641 output_shape = kwargs["output_shape"]
1642 input_shape = kwargs["input_shape"]
1643 axis = kwargs["axis"]
1644 valid_params = axis >= 0 and axis < len(input_shape)
1645
1646 if valid_params and (len(input_shape) - 1) != len(output_shape):
1647 error_result = True
1648
1649 info_dict = {
1650 "error_name": error_name,
1651 "error_result": error_result,
1652 "error_reason": error_reason,
1653 "param_reqs": param_reqs,
1654 }
1655 return info_dict
1656
1657 @staticmethod
1658 def evKernelSmallerOne(check=False, **kwargs):
1659 error_name = ErrorIf.KernelSmallerOne
1660 param_reqs = {"rank": None, "dtype": None, "shape": None}
1661 error_result = False
1662 error_reason = "At least one kernel dimension is smaller than zero"
1663
1664 if check:
1665 kernel = kwargs["kernel"]
1666 if min(kernel) < 1:
1667 error_result = True
1668
1669 info_dict = {
1670 "error_name": error_name,
1671 "error_result": error_result,
1672 "error_reason": error_reason,
1673 "param_reqs": param_reqs,
1674 }
1675 return info_dict
1676
1677 @staticmethod
1678 def evStrideSmallerOne(check=False, **kwargs):
1679 error_name = ErrorIf.StrideSmallerOne
1680 param_reqs = {"rank": None, "dtype": None, "shape": None}
1681 error_result = False
1682 error_reason = "At least one stride dimension is smaller than zero"
1683
1684 if check:
1685 stride = kwargs["stride"]
1686 if min(stride) < 1:
1687 error_result = True
1688
1689 info_dict = {
1690 "error_name": error_name,
1691 "error_result": error_result,
1692 "error_reason": error_reason,
1693 "param_reqs": param_reqs,
1694 }
1695 return info_dict
1696
1697 @staticmethod
1698 def evDilationSmallerOne(check=False, **kwargs):
1699 error_result = check and min(kwargs["dilation"]) < 1
1700 return {
1701 "error_name": ErrorIf.DilationSmallerOne,
1702 "error_reason": "At least one dilation is smaller than one",
1703 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1704 "error_result": error_result,
1705 }
1706
1707 @staticmethod
1708 def evScaleTrue(check=False, **kwargs):
1709 error_name = ErrorIf.ScaleTrue
1710 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1711 error_result = False
1712 error_reason = "Scale set to true but input type is INT48"
1713
1714 if check:
1715 input_dtype = kwargs["input_dtype"]
1716 scale32 = kwargs["scale32"]
1717 if scale32 and input_dtype == DType.INT48:
1718 error_result = True
1719
1720 info_dict = {
1721 "error_name": error_name,
1722 "error_result": error_result,
1723 "error_reason": error_reason,
1724 "param_reqs": param_reqs,
1725 }
1726 return info_dict
1727
1728 @staticmethod
1729 def evScaleNotTrue(check=False, **kwargs):
1730 error_name = ErrorIf.ScaleNotTrue
1731 param_reqs = {"rank": None, "dtype": None, "shape": None}
1732 error_result = False
1733 error_reason = "Scale set to false but double round set to true"
1734
1735 if check:
1736 scale32 = kwargs["scale32"]
1737 double_round = kwargs["double_round"]
1738 if not scale32 and double_round:
1739 error_result = True
1740
1741 info_dict = {
1742 "error_name": error_name,
1743 "error_result": error_result,
1744 "error_reason": error_reason,
1745 "param_reqs": param_reqs,
1746 }
1747 return info_dict
1748
1749 @staticmethod
1750 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1751 error_name = ErrorIf.TensorSizeInputOutputMismatch
1752 param_reqs = {"rank": None, "dtype": None, "shape": None}
1753 error_result = False
1754 error_reason = "Input tensor size does not match output tensor size"
1755
1756 if check:
1757 input_shape = kwargs["input_shape"]
1758 output_shape = kwargs["output_shape"]
1759 input_size = np.prod(input_shape)
1760 output_size = np.prod(output_shape)
1761 if input_size != output_size:
1762 error_result = True
1763
1764 info_dict = {
1765 "error_name": error_name,
1766 "error_result": error_result,
1767 "error_reason": error_reason,
1768 "param_reqs": param_reqs,
1769 }
1770 return info_dict
1771
1772 @staticmethod
1773 def evStartSmallerZero(check=False, **kwargs):
1774 error_name = ErrorIf.StartSmallerZero
1775 param_reqs = {"rank": None, "dtype": None, "shape": None}
1776 error_result = False
1777 error_reason = "Starting point smaller than zero"
1778
1779 if check:
1780 input_shape = kwargs["input_shape"]
1781 start = kwargs["start"]
1782 rank = len(input_shape)
1783 if len(start) == rank:
1784 for index in range(rank):
1785 if start[index] < 0:
1786 error_result = True
1787
1788 info_dict = {
1789 "error_name": error_name,
1790 "error_result": error_result,
1791 "error_reason": error_reason,
1792 "param_reqs": param_reqs,
1793 }
1794 return info_dict
1795
1796 @staticmethod
1797 def evSizeSmallerEqualZero(check=False, **kwargs):
1798 error_name = ErrorIf.SizeSmallerEqualZero
1799 param_reqs = {"rank": None, "dtype": None, "shape": None}
1800 error_result = False
1801 error_reason = "Size smaller than or equal to zero"
1802
1803 if check:
1804 input_shape = kwargs["input_shape"]
1805 size = kwargs["size"]
1806 rank = len(input_shape)
1807 if len(size) == rank:
1808 for index in range(rank):
1809 if size[index] <= 0:
1810 error_result = True
1811
1812 info_dict = {
1813 "error_name": error_name,
1814 "error_result": error_result,
1815 "error_reason": error_reason,
1816 "param_reqs": param_reqs,
1817 }
1818 return info_dict
1819
1820 @staticmethod
1821 def evStartSizeOutsideBounds(check=False, **kwargs):
1822 error_name = ErrorIf.StartSizeOutsideBounds
1823 param_reqs = {"rank": None, "dtype": None, "shape": None}
1824 error_result = False
1825 error_reason = "starting point plus size larger than input dimension"
1826
1827 if check:
1828 input_shape = kwargs["input_shape"]
1829 start = kwargs["start"]
1830 size = kwargs["size"]
1831 rank = len(input_shape)
1832 if len(start) == rank and len(size) == rank:
1833 for index in range(rank):
1834 if start[index] + size[index] > input_shape[index]:
1835 error_result = True
1836
1837 info_dict = {
1838 "error_name": error_name,
1839 "error_result": error_result,
1840 "error_reason": error_reason,
1841 "param_reqs": param_reqs,
1842 }
1843 return info_dict
1844
1845 @staticmethod
1846 def evSizeOutputShapeMismatch(check=False, **kwargs):
1847 error_name = ErrorIf.SizeOutputShapeMismatch
1848 param_reqs = {"rank": None, "dtype": None, "shape": None}
1849 error_result = False
1850 error_reason = "Size does not match output dimension"
1851
1852 if check:
1853 input_shape = kwargs["input_shape"]
1854 output_shape = kwargs["output_shape"]
1855 size = kwargs["size"]
1856 rank = len(input_shape)
1857 if len(size) == rank:
1858 for index in range(rank):
1859 if size[index] != output_shape[index]:
1860 error_result = True
1861
1862 info_dict = {
1863 "error_name": error_name,
1864 "error_result": error_result,
1865 "error_reason": error_reason,
1866 "param_reqs": param_reqs,
1867 }
1868 return info_dict
1869
1870 @staticmethod
1871 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1872 error_name = ErrorIf.InputSizeStartLengthMismatch
1873 param_reqs = {"rank": None, "dtype": None, "shape": None}
1874 error_result = False
1875 error_reason = "rank of input not equal to length of start or size"
1876
1877 if check:
1878 input_shape = kwargs["input_shape"]
1879 start = kwargs["start"]
1880 size = kwargs["size"]
1881 rank = len(input_shape)
1882 if rank != len(start) or rank != len(size):
1883 error_result = True
1884
1885 info_dict = {
1886 "error_name": error_name,
1887 "error_result": error_result,
1888 "error_reason": error_reason,
1889 "param_reqs": param_reqs,
1890 }
1891 return info_dict
1892
1893 @staticmethod
1894 def evIndexOutsideBounds(check=False, **kwargs):
1895 error_name = ErrorIf.IndexOutsideBounds
1896 param_reqs = {"rank": None, "dtype": None, "shape": None}
1897 error_result = False
1898 error_reason = "Index outside of allowed bounds"
1899
1900 if check:
1901 input_shape = kwargs["input_shape"]
1902 perms = kwargs["perms"]
1903 rank = len(input_shape)
1904
1905 for index in perms:
1906 if index < 0 or index > rank:
1907 error_result = True
1908
1909 info_dict = {
1910 "error_name": error_name,
1911 "error_result": error_result,
1912 "error_reason": error_reason,
1913 "param_reqs": param_reqs,
1914 }
1915 return info_dict
1916
1917 @staticmethod
1918 def evIndexUsedTwice(check=False, **kwargs):
1919 error_name = ErrorIf.IndexUsedTwice
1920 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1921 error_result = False
1922 error_reason = "Index used multiple times"
1923
1924 if check:
1925 perms = kwargs["perms"]
1926
1927 unique_indices = []
1928 for index in perms:
1929 if index in unique_indices:
1930 error_result = True
1931 else:
1932 unique_indices.append(index)
1933
1934 info_dict = {
1935 "error_name": error_name,
1936 "error_result": error_result,
1937 "error_reason": error_reason,
1938 "param_reqs": param_reqs,
1939 }
1940 return info_dict
1941
1942 @staticmethod
1943 def evMaxSmallerMin(check=False, **kwargs):
1944 error_name = ErrorIf.MaxSmallerMin
1945 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1946 error_result = False
1947 error_reason = "Max value smaller than min value"
1948
1949 if check:
1950 max_val = kwargs["max_val"]
1951 min_val = kwargs["min_val"]
1952 if max_val < min_val:
1953 error_result = True
1954
1955 info_dict = {
1956 "error_name": error_name,
1957 "error_result": error_result,
1958 "error_reason": error_reason,
1959 "param_reqs": param_reqs,
1960 }
1961 return info_dict
1962
1963 @staticmethod
1964 def evConcatInputRankMismatch(check=False, **kwargs):
1965 error_name = ErrorIf.ConcatInputRankMismatch
1966 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1967 error_result = False
1968 error_reason = "Input ranks are not identical"
1969
1970 if check:
1971 inputs = kwargs["inputs"]
1972 input_shape = kwargs["input_shape"]
1973 for input in inputs:
1974 if len(input.shape) != len(input_shape):
1975 error_result = True
1976
1977 info_dict = {
1978 "error_name": error_name,
1979 "error_result": error_result,
1980 "error_reason": error_reason,
1981 "param_reqs": param_reqs,
1982 }
1983 return info_dict
1984
1985 @staticmethod
1986 def evConcatInputDimMismatch(check=False, **kwargs):
1987 error_name = ErrorIf.ConcatInputDimMismatch
1988 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1989 error_result = False
1990 error_reason = "Input dimensions differ on too many axes"
1991
1992 if check:
1993 inputs = kwargs["inputs"]
1994 input_shape = kwargs["input_shape"]
1995 axis = kwargs["axis"]
1996
1997 # Ensure rank is valid before checking dims.
1998 valid_rank = True
1999 for input in inputs:
2000 if len(input.shape) != len(input_shape):
2001 valid_rank = False
2002
2003 if valid_rank:
2004 for input in inputs:
2005 for i, dim in enumerate(input.shape):
2006 if dim != input_shape[i] and axis != i:
2007 error_result = True
2008
2009 info_dict = {
2010 "error_name": error_name,
2011 "error_result": error_result,
2012 "error_reason": error_reason,
2013 "param_reqs": param_reqs,
2014 }
2015 return info_dict
2016
2017 @staticmethod
2018 def evConcatShapeSumMismatch(check=False, **kwargs):
2019 error_name = ErrorIf.ConcatShapeSumMismatch
2020 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2021 error_result = False
2022 error_reason = "Sum of dimensions on axis not equal to output dimension"
2023
2024 if check:
2025 inputs = kwargs["inputs"]
2026 input_shape = kwargs["input_shape"]
2027 output_shape = kwargs["output_shape"]
2028 axis = kwargs["axis"]
2029
2030 # Ensure rank is valid before checking dims.
2031 valid_params = True
2032 for input in inputs:
2033 if len(input.shape) != len(input_shape):
2034 valid_params = False
2035 if axis < 0 or axis > len(input_shape):
2036 valid_params = False
2037
2038 if valid_params:
2039 axis_dim_sum = 0
2040 for input in inputs:
2041 axis_dim_sum += input.shape[axis]
2042
2043 if axis_dim_sum != output_shape[axis]:
2044 error_result = True
2045
2046 info_dict = {
2047 "error_name": error_name,
2048 "error_result": error_result,
2049 "error_reason": error_reason,
2050 "param_reqs": param_reqs,
2051 }
2052 return info_dict
2053
2054 @staticmethod
2055 def evInputListThenGraphMismatch(check=False, **kwargs):
2056 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2057 param_reqs = {"rank": None, "dtype": None, "shape": None}
2058 error_result = False
2059 error_reason = "Input list shape does not match then-graph shape"
2060
2061 if check:
2062 a = kwargs["a"]
2063 b = kwargs["b"]
2064 basicBlocks = kwargs["basicBlocks"]
2065 then_block = basicBlocks[1]
2066 then_inputs = then_block.inputs
2067 then_tens = then_block.tensors
2068 if (a.shape != then_tens[then_inputs[0]].shape) or (
2069 b.shape != then_tens[then_inputs[1]].shape
2070 ):
2071 error_result = True
2072
2073 info_dict = {
2074 "error_name": error_name,
2075 "error_result": error_result,
2076 "error_reason": error_reason,
2077 "param_reqs": param_reqs,
2078 }
2079 return info_dict
2080
2081 @staticmethod
2082 def evInputListElseGraphMismatch(check=False, **kwargs):
2083 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2084 param_reqs = {"rank": None, "dtype": None, "shape": None}
2085 error_result = False
2086 error_reason = "Input list shape does not match else-graph shape"
2087
2088 if check:
2089 a = kwargs["a"]
2090 b = kwargs["b"]
2091 basicBlocks = kwargs["basicBlocks"]
2092 else_block = basicBlocks[2]
2093 else_inputs = else_block.inputs
2094 else_tens = else_block.tensors
2095 if (a.shape != else_tens[else_inputs[0]].shape) or (
2096 b.shape != else_tens[else_inputs[1]].shape
2097 ):
2098 error_result = True
2099
2100 info_dict = {
2101 "error_name": error_name,
2102 "error_result": error_result,
2103 "error_reason": error_reason,
2104 "param_reqs": param_reqs,
2105 }
2106 return info_dict
2107
2108 @staticmethod
2109 def evOutputListThenGraphMismatch(check=False, **kwargs):
2110 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2111 param_reqs = {"rank": None, "dtype": None, "shape": None}
2112 error_result = False
2113 error_reason = "Output list shape does not match then-graph shape"
2114
2115 if check:
2116 basicBlocks = kwargs["basicBlocks"]
2117 cond_block = basicBlocks[0]
2118 cond_outputs = cond_block.outputs
2119 cond_tens = cond_block.tensors
2120 then_block = basicBlocks[1]
2121 then_outputs = then_block.outputs
2122 then_tens = then_block.tensors
2123 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2124 error_result = True
2125
2126 info_dict = {
2127 "error_name": error_name,
2128 "error_result": error_result,
2129 "error_reason": error_reason,
2130 "param_reqs": param_reqs,
2131 }
2132 return info_dict
2133
2134 @staticmethod
2135 def evOutputListElseGraphMismatch(check=False, **kwargs):
2136 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2137 param_reqs = {"rank": None, "dtype": None, "shape": None}
2138 error_result = False
2139 error_reason = "Output list shape does not match else-graph shape"
2140
2141 if check:
2142 basicBlocks = kwargs["basicBlocks"]
2143 cond_block = basicBlocks[0]
2144 cond_outputs = cond_block.outputs
2145 cond_tens = cond_block.tensors
2146 else_block = basicBlocks[2]
2147 else_outputs = else_block.outputs
2148 else_tens = else_block.tensors
2149 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2150 error_result = True
2151
2152 info_dict = {
2153 "error_name": error_name,
2154 "error_result": error_result,
2155 "error_reason": error_reason,
2156 "param_reqs": param_reqs,
2157 }
2158 return info_dict
2159
2160 @staticmethod
2161 def evInputListOutputListMismatch(check=False, **kwargs):
2162 error_name = ErrorIf.InputListOutputListMismatch
2163 param_reqs = {"rank": None, "dtype": None, "shape": None}
2164 error_result = False
2165 error_reason = "Input list does not match output list"
2166
2167 if check:
2168 basicBlocks = kwargs["basicBlocks"]
2169 while_block = basicBlocks[0]
2170 while_inputs = while_block.inputs
2171 while_outputs = while_block.outputs
2172 while_tens = while_block.tensors
2173 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2174 error_result = True
2175
2176 info_dict = {
2177 "error_name": error_name,
2178 "error_result": error_result,
2179 "error_reason": error_reason,
2180 "param_reqs": param_reqs,
2181 }
2182 return info_dict
2183
2184 @staticmethod
2185 def evInputListCondGraphMismatch(check=False, **kwargs):
2186 error_name = ErrorIf.InputListCondGraphMismatch
2187 param_reqs = {"rank": None, "dtype": None, "shape": None}
2188 error_result = False
2189 error_reason = "Input list does not match cond graph"
2190
2191 if check:
2192 basicBlocks = kwargs["basicBlocks"]
2193 while_block = basicBlocks[0]
2194 while_inputs = while_block.inputs
2195 while_tens = while_block.tensors
2196 cond_block = basicBlocks[1]
2197 cond_inputs = cond_block.inputs
2198 cond_tens = cond_block.tensors
2199 if (
2200 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2201 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2202 error_result = True
2203
2204 info_dict = {
2205 "error_name": error_name,
2206 "error_result": error_result,
2207 "error_reason": error_reason,
2208 "param_reqs": param_reqs,
2209 }
2210 return info_dict
2211
2212 @staticmethod
2213 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2214 error_name = ErrorIf.InputListBodyGraphInputMismatch
2215 param_reqs = {"rank": None, "dtype": None, "shape": None}
2216 error_result = False
2217 error_reason = "Input list does not match body graph input"
2218
2219 if check:
2220 basicBlocks = kwargs["basicBlocks"]
2221 while_block = basicBlocks[0]
2222 while_inputs = while_block.inputs
2223 while_tens = while_block.tensors
2224 body_block = basicBlocks[2]
2225 body_outputs = body_block.inputs
2226 body_tens = body_block.tensors
2227 if (
2228 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2229 ) or (
2230 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2231 ):
2232 error_result = True
2233
2234 info_dict = {
2235 "error_name": error_name,
2236 "error_result": error_result,
2237 "error_reason": error_reason,
2238 "param_reqs": param_reqs,
2239 }
2240 return info_dict
2241
2242 @staticmethod
2243 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2244 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2245 param_reqs = {"rank": None, "dtype": None, "shape": None}
2246 error_result = False
2247 error_reason = "Input list does not match body graph output"
2248
2249 if check:
2250 basicBlocks = kwargs["basicBlocks"]
2251 while_block = basicBlocks[0]
2252 while_inputs = while_block.inputs
2253 while_tens = while_block.tensors
2254 body_block = basicBlocks[2]
2255 body_outputs = body_block.outputs
2256 body_tens = body_block.tensors
2257 if (
2258 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2259 ) or (
2260 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2261 ):
2262 error_result = True
2263 info_dict = {
2264 "error_name": error_name,
2265 "error_result": error_result,
2266 "error_reason": error_reason,
2267 "param_reqs": param_reqs,
2268 }
2269 return info_dict
2270
2271 @staticmethod
2272 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2273 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2274 param_reqs = {"rank": None, "dtype": None, "shape": None}
2275 error_result = False
2276 error_reason = "Cond graph output is not a match list of booleans"
2277
2278 if check:
2279 basicBlocks = kwargs["basicBlocks"]
2280 cond_block = basicBlocks[1]
2281 cond_outputs = cond_block.outputs
2282 cond_tens = cond_block.tensors
2283 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2284 error_result = True
2285
2286 info_dict = {
2287 "error_name": error_name,
2288 "error_result": error_result,
2289 "error_reason": error_reason,
2290 "param_reqs": param_reqs,
2291 }
2292 return info_dict
2293
2294
2295class TosaInvalidValidator:
2296 @staticmethod
2297 def ivWrongDataTypeOrModeResize(**kwargs):
2298 input_dtype = kwargs["input_dtype"]
2299 args = kwargs["args"]
2300 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002301 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002302
2303 if mode == ResizeMode.BILINEAR:
2304 # Invalid output data type / Invalid input datatype
2305 return (
2306 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002307 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002308 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002309 and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002310 )
2311 elif mode == ResizeMode.NEAREST:
2312 # Invalid output data type / Invalid input datatype
2313 return (input_dtype != output_dtype) or (
James Ward8b390432022-08-12 20:48:56 +01002314 input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002315 )
2316 else:
2317 # Invalid resize mode
2318 return True
2319
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when the op's parameters would produce an invalid
        (smaller than 1, or inconsistent) output height/width.

        Expected kwargs:
            opName: operator name string (e.g. "max_pool2d", "conv2d", ...)
            shapeList: list of shapes; [0] is the input, [1] the filter
            args: the op's argument list; the index of each argument
                depends on the op (see NOTE below)

        NOTE(review): the per-op layout of `args` is not visible from this
        file.  The indices below assume max_pool2d args start at stride
        while other ops carry an extra leading argument (per the comment),
        yet `args[2]` is reused as pad, kernel, output_shape and dilations
        across branches — confirm against the argument builders.
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        # MaxPool2D has no accum_dtype arg
        stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
        strides = args[stride_idx]
        padding = args[pad_idx]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            # NOTE(review): for avg_pool2d pad_idx is also 2, so kernel and
            # padding would read the same slot — verify the arg layout.
            kernel_shape = args[2]
            # Standard pooling output-size formula on the H (dim 1) and
            # W (dim 2) axes; floor division by the stride.
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[2]
            # Kernel H/W are the middle dims of the OHWI-style filter shape.
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # The requested output shape is valid if it matches the computed
            # size under ANY of the three conventional padding schemes.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            # NOTE(review): dilations shares index 2 with padding here —
            # confirm the arg layout for conv ops.
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters put kernel H/W first; others use the middle dims.
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            # Standard dilated-convolution output-size formula per spatial dim.
            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        # Unreachable for the ops this validator is registered against.
        assert False, f"Unrecognized Op: {opName}"
2416
2417 @staticmethod
2418 def ivNonPositiveOutputShape(**kwargs):
2419 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002420 output_shape = args[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002421 if output_shape[1] <= 0 or output_shape[2] <= 0:
2422 # Negative output shape
2423 return True
2424 return False