blob: a5a834fcebbacd428536a01eafef68edfe53a06e [file] [log] [blame]
Won Jeon74342e52024-01-09 00:34:40 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Luke Hutton261b7b62023-01-10 14:50:31 +00003import math
4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01006from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from generator.tosa_utils import product
8from generator.tosa_utils import usableDTypes
9from generator.tosa_utils import valueToName
10from tosa.DType import DType
11from tosa.Op import Op
12from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000013
Matthew Haddone86fd342021-09-07 16:12:21 +010014
class ErrorIf:
    """Names of the ERROR_IF conditions exercised by the test generator.

    Each attribute is a string constant equal to its own name; these values
    are used both to select which error to provoke in the argument
    generators and to match against the ``error_name`` reported by the
    validator functions.
    """

    # Legacy explicit ``object`` base removed - this is Python 3 code
    # (f-strings are used elsewhere in the file), so it was redundant.
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010087
88
class TosaErrorIfArgGen:
    """Helpers that corrupt otherwise-valid operator arguments so that a
    specific ERROR_IF condition is provoked in the reference model.

    Fixes over the previous revision:
    - ``eiSliceErrorIf`` now carries the ``@staticmethod`` decorator like
      every other helper in this class (it was missing, which made
      instance-style calls pass the instance as ``testGen``).
    - ``eiInvalidateInputOutputList`` compares against the ``ErrorIf``
      constants instead of raw string literals, for consistency with the
      rest of the file (the constant values are identical strings, so
      behavior is unchanged).
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Corrupt RESIZE arguments to trigger the requested ERROR_IF.

        Args:
            testGen: test generator supplying ``randInt`` and ``rng``
            error_name: ErrorIf condition to provoke (others leave args valid)
            mode: ResizeMode.NEAREST or ResizeMode.BILINEAR
            dtype: input tensor data type
            shapeList: input shapes (unused here, kept for interface parity)
            outputDType: valid output type, replaced when WrongOutputType
            scale: 4-element list, mutated in place
            offset: 2-element list, mutated in place
            border: 2-element list, mutated in place
        Returns:
            Tuple of (scale, offset, border, outputDType).
        """
        # NOTE(review): indices suggest scale is [y_n, y_d, x_n, x_d] and
        # offset/border are [y, x] - confirm against the RESIZE generator.
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # Numerators (indices 0 and 2) pushed just past the 1 << 11 limit
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # Denominators (indices 1 and 3) pushed to >= 16 * numerator
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Choose any output type that is invalid for this mode/input
            # combination (valid combos: NEAREST int8->int8, NEAREST
            # int16->int16, BILINEAR int8->int32, BILINEAR int16->int48,
            # float types map to themselves).
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one component made invalid.

        Returns (None, None, None) when the error is not one handled here
        (or the stride error's padding precondition does not hold).
        """
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when output_dtype is invalid for a RESCALE of input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor name lists for ERROR_IF checks.

        Randomly grows or shrinks the relevant list; the input_list/
        output_list arguments may be mutated in place (append) or replaced.
        """
        if error_name == ErrorIf.WrongInputList:
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == ErrorIf.WrongOutputList:
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            # Shrink every dimension (to a minimum of 1) until small enough
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) for SLICE corrupted to trigger error_name.

        Unrecognized error names return the arguments unchanged.
        """
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                # NOTE: appends mutate the caller's start/size lists
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output types for input_dtype."""
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
331
332
333class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        if validator_fcns is None:
            # Nothing to do
            return True
        overall_result = True
        for val_fcn in validator_fcns:
            # Run each validator in "check" mode against the same op arguments
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                # The requested ERROR_IF fired - record the expected
                # non-zero return code on the serialized test
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                # Dump the argument values (dtypes by name) to aid debugging
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result
381
382 @staticmethod
383 def evWrongInputType(check=False, **kwargs):
384 error_result = False
385
386 # Find the unsupported input data types
387 op = kwargs["op"]
388 input_dtypes = op["types"]
389 allowed_input_dtypes = {
390 t[0] if isinstance(t, list) else t for t in input_dtypes
391 }
392 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
393
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100394 # Turn the wrong dtypes into required list of types
395 if op["op"] in [
396 Op.FULLY_CONNECTED,
397 Op.CONV2D,
398 Op.CONV3D,
399 Op.DEPTHWISE_CONV2D,
400 Op.TRANSPOSE_CONV2D,
401 ]:
402 wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
403
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100404 if op["op"] == Op.CLAMP:
405 wrong_input_dtypes.remove(DType.INT48)
406
407 if check:
408 input_dtype = kwargs["input_dtype"]
409 if input_dtype not in allowed_input_dtypes:
410 error_result = True
411
412 info_dict = {
413 "error_name": ErrorIf.WrongInputType,
414 "error_result": error_result,
415 "error_reason": "Input data type not supported for this operator",
416 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
417 }
418 return info_dict
419
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for the WrongOutputType ERROR_IF condition.

        Per operator, checks that the output data type supplied in kwargs is
        valid for the given input data type (and resize mode for RESIZE).

        Args:
            check: when True, validate the dtypes in kwargs
            kwargs: op, input_dtype, output_dtype (plus mode for RESIZE)
        Returns:
            info dictionary (error_name/error_result/error_reason/param_reqs)
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # Valid combos: NEAREST int8->int8, NEAREST int16->int16,
                # BILINEAR int8->int32, BILINEAR int16->int48, and each
                # float type maps to itself.
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE rules are shared with the argument generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always produces INT32 indices
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL widens to INT32; float MUL keeps its type
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison ops always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input type lists its complete set of valid cast targets
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                # output_dtype is a sequence here; every output must match
                # the input dtype
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                # Relies on 'and' binding tighter than 'or'
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
                if output_dtype != DType.SHAPE:
                    error_result = True

            else:
                # Default rule: output type must match input type
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
615
616 @staticmethod
617 def evWrongRank(check=False, **kwargs):
618 all_ranks = (1, 2, 3, 4, 5)
619
620 # Make a list of incorrect ranks
621 assert "op" in kwargs
622 op = kwargs["op"]
623 rmin, rmax = op["rank"]
624 rank_range = range(rmin, rmax + 1)
625 incorrect_ranks = list(set(all_ranks) - set(rank_range))
626 # Remove small incorrect ranks to avoid index errors
627 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
628 # Set minimum incorrect rank to 3 to avoid index error
629 if op["op"] in [Op.RESIZE]:
630 incorrect_ranks = [3, 5]
631 elif op["op"] in [Op.TRANSPOSE]:
632 incorrect_ranks = [7, 8]
633 elif op["op"] in [Op.CONV3D]:
634 incorrect_ranks = [6, 7]
635
636 error_name = ErrorIf.WrongRank
637 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
638 error_result = False
639 error_reason = "Rank not supported for this operator"
640
641 if check:
642 input_shape = kwargs["input_shape"]
643
644 if (
645 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
646 and len(input_shape) != 4
647 ):
648 error_result = True
649 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
650 error_result = True
651 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
652 error_result = True
653 else:
654 if len(input_shape) not in rank_range:
655 error_result = True
656
657 info_dict = {
658 "error_name": error_name,
659 "error_result": error_result,
660 "error_reason": error_reason,
661 "param_reqs": param_reqs,
662 }
663 return info_dict
664
665 @staticmethod
666 def evWrongInputList(check=False, **kwargs):
667 error_name = ErrorIf.WrongInputList
668 param_reqs = {"rank": None, "dtype": None, "shape": None}
669 error_result = False
670 error_reason = "Op input list does not match expected input"
671
672 if check:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100673 input_list = kwargs["input_list"]
674 num_operands = kwargs["num_operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100675 if len(input_list) != num_operands:
676 error_result = True
677
678 info_dict = {
679 "error_name": error_name,
680 "error_result": error_result,
681 "error_reason": error_reason,
682 "param_reqs": param_reqs,
683 }
684 return info_dict
685
686 @staticmethod
687 def evWrongOutputList(check=False, **kwargs):
688 error_name = ErrorIf.WrongOutputList
689 param_reqs = {"rank": None, "dtype": None, "shape": None}
690 error_result = False
691 error_reason = "Op output list does not match expected output"
692
693 if check:
Luke Hutton261b7b62023-01-10 14:50:31 +0000694 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100695 output_list = kwargs["output_list"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000696 expected_length = 1
Luke Hutton57287132023-02-06 14:54:18 +0000697 if op["op"] in [Op.FFT2D, Op.RFFT2D]:
Luke Hutton261b7b62023-01-10 14:50:31 +0000698 expected_length = 2
699
700 if len(output_list) != expected_length:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100701 error_result = True
702
703 info_dict = {
704 "error_name": error_name,
705 "error_result": error_result,
706 "error_reason": error_reason,
707 "param_reqs": param_reqs,
708 }
709 return info_dict
710
711 @staticmethod
712 def evMaxDimExceeded(check=False, **kwargs):
713 error_name = ErrorIf.MaxDimExceeded
714 param_reqs = {
715 "rank": [4, 4],
716 "dtype": [DType.INT8],
717 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
718 }
719 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100720 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100721
722 if check:
723 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100724 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100725 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100726 (input_shape[1] >= MAX_RESIZE_DIMENSION)
727 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
728 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
729 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100730 ):
731 error_result = True
732
733 info_dict = {
734 "error_name": error_name,
735 "error_result": error_result,
736 "error_reason": error_reason,
737 "param_reqs": param_reqs,
738 }
739 return info_dict
740
741 @staticmethod
742 def evBatchMismatch(check=False, **kwargs):
743 error_name = ErrorIf.BatchMismatch
Luke Hutton261b7b62023-01-10 14:50:31 +0000744 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100745 error_result = False
746 error_reason = "Input batch size not equal to output batch size"
747
748 assert "op" in kwargs
749 op = kwargs["op"]
750 rmin, rmax = op["rank"]
751 rank_range = range(rmin, rmax + 1)
752
753 if check:
754 input_shape = kwargs["input_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100755
Luke Hutton261b7b62023-01-10 14:50:31 +0000756 for output in kwargs["result_tensors"]:
757 output_shape = (
758 output.shape
759 ) # Note batch is expected to be the first dim
760 if (len(input_shape) in rank_range) and (
761 input_shape[0] != output_shape[0]
762 ):
763 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100764
765 info_dict = {
766 "error_name": error_name,
767 "error_result": error_result,
768 "error_reason": error_reason,
769 "param_reqs": param_reqs,
770 }
771 return info_dict
772
773 @staticmethod
774 def evChannelMismatch(check=False, **kwargs):
775 error_name = ErrorIf.ChannelMismatch
776 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
777 error_result = False
778 error_reason = "Input channel size not equal to output channel size"
779
780 assert "op" in kwargs
781 op = kwargs["op"]
782 rmin, rmax = op["rank"]
783 rank_range = range(rmin, rmax + 1)
784
785 if check:
786 input_shape = kwargs["input_shape"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000787 for output in kwargs["result_tensors"]:
788 output_shape = output.shape # Note this is just (N, OH, OW, C)
789 if (len(input_shape) in rank_range) and (
790 input_shape[3] != output_shape[3]
791 ):
792 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100793
794 info_dict = {
795 "error_name": error_name,
796 "error_result": error_result,
797 "error_reason": error_reason,
798 "param_reqs": param_reqs,
799 }
800 return info_dict
801
802 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100803 def evScaleSmallerEqualZero(check=False, **kwargs):
804 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100805 param_reqs = {"rank": None, "dtype": None, "shape": None}
806 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100807 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100808
809 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100810 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100811
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100812 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100813 error_result = True
814
815 info_dict = {
816 "error_name": error_name,
817 "error_result": error_result,
818 "error_reason": error_reason,
819 "param_reqs": param_reqs,
820 }
821 return info_dict
822
823 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100824 def evScaleNLargerMax(check=False, **kwargs):
825 error_name = ErrorIf.ScaleNLargerMax
826 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100827 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100828 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100829
830 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100831 scale = kwargs["scale"]
832
833 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
834 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100835
836 info_dict = {
837 "error_name": error_name,
838 "error_result": error_result,
839 "error_reason": error_reason,
840 "param_reqs": param_reqs,
841 }
842 return info_dict
843
844 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100845 def evScaleDLargerMax(check=False, **kwargs):
846 error_name = ErrorIf.ScaleDLargerMax
847 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100848 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100849 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100850
851 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100852 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100853
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100854 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
855 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100856 ):
857 error_result = True
858
859 info_dict = {
860 "error_name": error_name,
861 "error_result": error_result,
862 "error_reason": error_reason,
863 "param_reqs": param_reqs,
864 }
865 return info_dict
866
867 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100868 def evOffsetSmallerMin(check=False, **kwargs):
869 error_name = ErrorIf.OffsetSmallerMin
870 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100871 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100872 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100873
874 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100875 scale = kwargs["scale"]
876 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100877
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100878 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100879 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100880 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100881 error_result = True
882
883 info_dict = {
884 "error_name": error_name,
885 "error_result": error_result,
886 "error_reason": error_reason,
887 "param_reqs": param_reqs,
888 }
889 return info_dict
890
891 @staticmethod
892 def evOffsetLargerEqualMax(check=False, **kwargs):
893 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100894 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100895 error_result = False
896 error_reason = "Offset value larger than or equal to maximum value"
897
898 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100899 scale = kwargs["scale"]
900 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100901
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100902 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
903 error_result = True
904 elif (
905 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
906 ):
907 error_result = True
908
909 info_dict = {
910 "error_name": error_name,
911 "error_result": error_result,
912 "error_reason": error_reason,
913 "param_reqs": param_reqs,
914 }
915 return info_dict
916
917 @staticmethod
918 def evBorderSmallerMin(check=False, **kwargs):
919 error_name = ErrorIf.BorderSmallerMin
920 param_reqs = {"rank": None, "dtype": None, "shape": None}
921 error_result = False
922 error_reason = "Border value smaller than minimum value"
923
924 if check:
925 scale = kwargs["scale"]
926 border = kwargs["border"]
927
928 if (
929 scale[0] > 0
930 and scale[0] <= (1 << 11)
931 and (border[0] < (-16 * scale[0]))
932 ):
933 error_result = True
934 elif (
935 scale[2] > 0
936 and scale[2] <= (1 << 11)
937 and (border[1] < (-16 * scale[2]))
938 ):
939 error_result = True
940
941 info_dict = {
942 "error_name": error_name,
943 "error_result": error_result,
944 "error_reason": error_reason,
945 "param_reqs": param_reqs,
946 }
947 return info_dict
948
949 @staticmethod
950 def evBorderLargerEqualMax(check=False, **kwargs):
951 error_name = ErrorIf.BorderLargerEqualMax
952 param_reqs = {"rank": None, "dtype": None, "shape": None}
953 error_result = False
954 error_reason = "Border value larger than or equal to maximum value"
955
956 if check:
957 scale = kwargs["scale"]
958 border = kwargs["border"]
959
960 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
961 error_result = True
962 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
963 error_result = True
964
965 info_dict = {
966 "error_name": error_name,
967 "error_result": error_result,
968 "error_reason": error_reason,
969 "param_reqs": param_reqs,
970 }
971 return info_dict
972
973 @staticmethod
974 def checkResizeParams(scale, offset, border):
975 return (
976 min(scale) > 0
977 and max(scale[0], scale[2]) <= (1 << 11)
978 and scale[1] < 16 * scale[0]
979 and scale[3] < 16 * scale[2]
980 and offset[0] >= -scale[0]
981 and offset[1] >= -scale[2]
982 and offset[0] < 16 * scale[0]
983 and offset[1] < 16 * scale[2]
984 and border[0] >= -16 * scale[0]
985 and border[1] >= -16 * scale[2]
986 and border[0] < scale[0]
987 and border[1] < scale[2]
988 )
989
990 @staticmethod
991 def evResizeOutputShapeMismatch(check=False, **kwargs):
992 error_name = ErrorIf.ResizeOutputShapeMismatch
993 param_reqs = {"rank": None, "dtype": None, "shape": None}
994 error_result = False
995 error_reason = (
996 "Mismatch between output shape provided and expected output shape"
997 )
998
999 if check:
1000 input_shape = kwargs["input_shape"]
1001 output_shape = kwargs["output_shape"]
1002 scale = kwargs["scale"]
1003 offset = kwargs["offset"]
1004 border = kwargs["border"]
1005
1006 # Ensure parameters are valid
1007 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
1008
1009 if (
1010 params_valid
1011 and max(output_shape) < MAX_RESIZE_DIMENSION
1012 and max(input_shape) < MAX_RESIZE_DIMENSION
1013 ):
1014 output_y = (
1015 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
1016 ) // scale[1] + 1
1017 output_x = (
1018 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
1019 ) // scale[3] + 1
1020
1021 if [output_y, output_x] != output_shape[1:-1]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001022 error_result = True
1023
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001024 info_dict = {
1025 "error_name": error_name,
1026 "error_result": error_result,
1027 "error_reason": error_reason,
1028 "param_reqs": param_reqs,
1029 }
1030 return info_dict
1031
1032 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001033 def evResizeOutputShapeNonInteger(check=False, **kwargs):
1034 error_name = ErrorIf.ResizeOutputShapeNonInteger
1035 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001036 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001037 error_reason = "Parameters do not yield exact integer output dimensions"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001038
1039 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001040 input_shape = kwargs["input_shape"]
1041 scale = kwargs["scale"]
1042 offset = kwargs["offset"]
1043 border = kwargs["border"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001045 # Ensure parameters are valid
1046 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001047
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001048 if params_valid:
1049 remainder_y = (
1050 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
1051 ) % scale[1]
1052 remainder_x = (
1053 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
1054 ) % scale[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001055
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001056 if max(remainder_y, remainder_x) > 0:
1057 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001058
1059 info_dict = {
1060 "error_name": error_name,
1061 "error_result": error_result,
1062 "error_reason": error_reason,
1063 "param_reqs": param_reqs,
1064 }
1065 return info_dict
1066
1067 @staticmethod
1068 def evRankMismatch(check=False, **kwargs):
1069 error_name = ErrorIf.RankMismatch
1070 param_reqs = {"rank": None, "dtype": None, "shape": None}
1071 error_result = False
1072 error_reason = "Input Rank does not match output rank"
1073
1074 if check:
1075 input1_shape = kwargs["input1"].shape
Luke Huttona4e48ca2023-02-22 11:53:48 +00001076 input2_shape = (
1077 kwargs["input2"].shape if "input2" in kwargs else input1_shape
1078 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001079 # In case of SELECT op
1080 input3_shape = (
1081 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1082 )
Luke Hutton261b7b62023-01-10 14:50:31 +00001083
1084 for output in kwargs["result_tensors"]:
1085 output_shape = output.shape
1086 if (
1087 (len(input1_shape) != len(output_shape))
1088 or (len(input2_shape) != len(output_shape))
1089 or (len(input3_shape) != len(output_shape))
1090 ):
1091 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092
1093 info_dict = {
1094 "error_name": error_name,
1095 "error_result": error_result,
1096 "error_reason": error_reason,
1097 "param_reqs": param_reqs,
1098 }
1099 return info_dict
1100
1101 @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        """ERROR_IF check: output dimensions disagree with the (broadcast) inputs.

        For the *_SHAPE ops the output must match input1 exactly; otherwise
        the expected output is the broadcast of all (equal-rank) inputs.
        Returns the standard error info dict.
        """
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            op = kwargs["op"]
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
                # Shape ops are elementwise with no broadcasting: output must
                # match the first input exactly
                output_shape = kwargs["result_tensors"][0].shape
                if input1_shape != output_shape:
                    error_result = True

            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                # Equal ranks: fold all inputs into one broadcast shape.
                # calculateBroadcastShape returns None for incompatible inputs.
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1143
1144 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001145 def _getZeroPoint(qinfo, index):
1146 """Return zero point value from quantization info.
1147
1148 Generally input_zp is index 0, output_zp is index 1
1149 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001150 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001151
1152 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001153 def evInputZeroPointNotZero(check=False, **kwargs):
1154 op = kwargs["op"]
1155 error_result = False
1156
1157 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001158 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001159
1160 # This does not apply to quantizable types
1161 inputDtypes = [
1162 dtype
1163 for dtype in op["types"]
1164 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1165 or (not isinstance(dtype, list) and dtype not in qTypes)
1166 ]
1167
1168 if check:
1169 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001170 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001171 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001172 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001173 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001174 (kwargs["input_dtype"], input_zero_point),
1175 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001176 ):
1177 if dtype not in qTypes and zp != 0:
1178 error_result = True
1179 break
1180 else:
1181 error_result = input_dtype not in qTypes and input_zero_point != 0
1182
1183 info_dict = {
1184 "error_name": ErrorIf.InputZeroPointNotZero,
1185 "error_result": error_result,
1186 "error_reason": "Input DType not INT8 and zero point not 0",
1187 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1188 }
1189 return info_dict
1190
1191 @staticmethod
1192 def evWeightZeroPointNotZero(check=False, **kwargs):
1193 op = kwargs["op"]
1194
1195 # exclude inputs with INT8 weights
1196 inputDtypes = [
1197 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1198 ]
1199
1200 error_name = ErrorIf.WeightZeroPointNotZero
1201 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1202 error_result = False
1203 error_reason = "Weight DType not INT8 and zero point not 0"
1204
1205 if check:
1206 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001207 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001208 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1209 error_result = True
1210
1211 info_dict = {
1212 "error_name": error_name,
1213 "error_result": error_result,
1214 "error_reason": error_reason,
1215 "param_reqs": param_reqs,
1216 }
1217 return info_dict
1218
1219 @staticmethod
1220 def evOutputZeroPointNotZero(check=False, **kwargs):
1221 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001222 inputDtypes = [
1223 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1224 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001225
1226 error_name = ErrorIf.OutputZeroPointNotZero
1227 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1228 error_result = False
1229 error_reason = "Output DType not INT8 and zero point not 0"
1230
1231 if check:
1232 input_dtype = kwargs["input_dtype"]
1233 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001234 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001235 if op["op"] == Op.AVG_POOL2D:
1236 if input_dtype != DType.INT8 and output_zero_point != 0:
1237 error_result = True
1238 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001239 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1240 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001241 ):
1242 error_result = True
1243
1244 info_dict = {
1245 "error_name": error_name,
1246 "error_result": error_result,
1247 "error_reason": error_reason,
1248 "param_reqs": param_reqs,
1249 }
1250 return info_dict
1251
1252 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001253 def evU16InputZeroPointNotValid(check=False, **kwargs):
1254 error_name = ErrorIf.U16InputZeroPointNotValid
1255 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1256 error_result = False
1257 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1258
1259 if check:
1260 input_dtype = kwargs["input_dtype"]
1261 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1262 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1263 0,
1264 32768,
1265 ]
1266
1267 info_dict = {
1268 "error_name": error_name,
1269 "error_result": error_result,
1270 "error_reason": error_reason,
1271 "param_reqs": param_reqs,
1272 }
1273 return info_dict
1274
1275 @staticmethod
1276 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1277 error_name = ErrorIf.U16OutputZeroPointNotValid
1278 param_reqs = {"rank": None, "dtype": None, "shape": None}
1279 error_result = False
1280 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1281
1282 if check:
1283 output_dtype = kwargs["output_dtype"]
1284 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1285
1286 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1287 0,
1288 32768,
1289 ]
1290
1291 info_dict = {
1292 "error_name": error_name,
1293 "error_result": error_result,
1294 "error_reason": error_reason,
1295 "param_reqs": param_reqs,
1296 }
1297 return info_dict
1298
1299 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001300 def evAxisSmallerZero(check=False, **kwargs):
1301 error_name = ErrorIf.AxisSmallerZero
1302 param_reqs = {"rank": None, "dtype": None, "shape": None}
1303 error_result = False
1304 error_reason = "Axis smaller than zero"
1305
1306 if check:
1307 axis = kwargs["axis"]
1308 if axis < 0:
1309 error_result = True
1310
1311 info_dict = {
1312 "error_name": error_name,
1313 "error_result": error_result,
1314 "error_reason": error_reason,
1315 "param_reqs": param_reqs,
1316 }
1317 return info_dict
1318
1319 @staticmethod
1320 def evAxisLargerRank(check=False, **kwargs):
1321 error_name = ErrorIf.AxisLargerRank
1322 param_reqs = {"rank": None, "dtype": None, "shape": None}
1323 error_result = False
1324 error_reason = "Axis larger than rank"
1325
1326 if check:
1327 axis = kwargs["axis"]
1328 shape = kwargs["input_shape"]
1329 if axis > len(shape):
1330 error_result = True
1331
1332 info_dict = {
1333 "error_name": error_name,
1334 "error_result": error_result,
1335 "error_reason": error_reason,
1336 "param_reqs": param_reqs,
1337 }
1338 return info_dict
1339
1340 @staticmethod
1341 def evShapeOfAxisNotOne(check=False, **kwargs):
1342 error_name = ErrorIf.ShapeOfAxisNotOne
1343 param_reqs = {"rank": None, "dtype": None, "shape": None}
1344 error_result = False
1345 error_reason = "shape[axis] is not equal to 1"
1346
1347 if check:
1348 axis = kwargs["axis"]
1349 shape = kwargs["output_shape"]
1350 if (0 <= axis < len(shape)) and shape[axis] != 1:
1351 error_result = True
1352
1353 info_dict = {
1354 "error_name": error_name,
1355 "error_result": error_result,
1356 "error_reason": error_reason,
1357 "param_reqs": param_reqs,
1358 }
1359 return info_dict
1360
1361 @staticmethod
1362 def evPadSmallerZero(check=False, **kwargs):
1363 error_name = ErrorIf.PadSmallerZero
1364 param_reqs = {"rank": None, "dtype": None, "shape": None}
1365 error_result = False
1366 error_reason = "At least one pad is smaller than zero"
1367
1368 if check:
1369 op = kwargs["op"]
1370 pad = kwargs["pad"]
1371 if op["op"] == Op.PAD:
1372 for padding in pad:
1373 if min(padding) < 0:
1374 error_result = True
1375 else:
1376 if min(pad) < 0:
1377 error_result = True
1378
1379 info_dict = {
1380 "error_name": error_name,
1381 "error_result": error_result,
1382 "error_reason": error_reason,
1383 "param_reqs": param_reqs,
1384 }
1385 return info_dict
1386
1387 @staticmethod
1388 def evPadLargerEqualKernel(check=False, **kwargs):
1389 error_name = ErrorIf.PadLargerEqualKernel
1390 param_reqs = {"rank": None, "dtype": None, "shape": None}
1391 error_result = False
1392 error_reason = "At least one pad is larger than kernel dimension"
1393
1394 if check:
1395 pad = kwargs["pad"]
Eric Kunzec1a97832022-07-01 16:56:09 -07001396 op = kwargs["op"]
1397 if op["op"] == Op.TRANSPOSE_CONV2D:
1398 # transpose_conv2d
1399 kernel = kwargs["weight_shape"][1:-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001400 if (
Eric Kunzec1a97832022-07-01 16:56:09 -07001401 pad[0] <= -kernel[0]
1402 or pad[1] <= -kernel[0]
1403 or pad[2] <= -kernel[1]
1404 or pad[3] <= -kernel[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001405 ):
1406 error_result = True
Eric Kunzec1a97832022-07-01 16:56:09 -07001407 else:
1408 # pooling op
1409 kernel = kwargs["kernel"]
1410 if min(pad) > 0 and min(kernel) > 1:
1411 if (
1412 pad[0] >= kernel[0]
1413 or pad[1] >= kernel[0]
1414 or pad[2] >= kernel[1]
1415 or pad[3] >= kernel[1]
1416 ):
1417 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001418
1419 info_dict = {
1420 "error_name": error_name,
1421 "error_result": error_result,
1422 "error_reason": error_reason,
1423 "param_reqs": param_reqs,
1424 }
1425 return info_dict
1426
1427 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001428 def evPadOutputShapeMismatch(check=False, **kwargs):
1429 error_name = ErrorIf.PadOutputShapeMismatch
1430 param_reqs = {"rank": None, "dtype": None, "shape": None}
1431 error_result = False
1432 error_reason = "Pad output shape mismatch for requested padding"
1433
1434 if check:
1435 pad = kwargs["pad"]
1436 input_shape = kwargs["input_shape"]
1437 output_shape = kwargs["output_shape"]
1438 for dim, padding in enumerate(pad):
1439 expected_size = input_shape[dim] + padding[0] + padding[1]
1440 if expected_size != output_shape[dim]:
1441 error_result = True
1442
1443 info_dict = {
1444 "error_name": error_name,
1445 "error_result": error_result,
1446 "error_reason": error_reason,
1447 "param_reqs": param_reqs,
1448 }
1449 return info_dict
1450
1451 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001452 def checkPoolingParams(kernel, stride, pad):
1453 return (
1454 min(kernel) >= 1
1455 and min(stride) >= 1
1456 and min(pad) >= 0
1457 and not (
1458 pad[0] >= kernel[0]
1459 or pad[1] >= kernel[0]
1460 or pad[2] >= kernel[1]
1461 or pad[3] >= kernel[1]
1462 )
1463 )
1464
1465 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001466 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1467 error_name = ErrorIf.PoolingOutputShapeMismatch
1468 param_reqs = {"rank": None, "dtype": None, "shape": None}
1469 error_result = False
1470 error_reason = (
1471 "Mismatch between output shape provided and expected output shape"
1472 )
1473
1474 if check:
1475 pad = kwargs["pad"]
1476 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1477
1478 kernel = kwargs["kernel"]
1479 kernel_y, kernel_x = kernel[0], kernel[1]
1480
1481 input_shape = kwargs["input_shape"]
1482 IH, IW = input_shape[1], input_shape[2]
1483
1484 output_shape = kwargs["output_shape"]
1485 OH, OW = output_shape[1], output_shape[2]
1486
1487 stride = kwargs["stride"]
1488 stride_y, stride_x = stride[0], stride[1]
1489
1490 # calculate correct height, width dimensions
1491 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001492 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1493 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001494
1495 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001496 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001497
1498 if params_valid and (OH != y_correct or OW != x_correct):
1499 error_result = True
1500
1501 info_dict = {
1502 "error_name": error_name,
1503 "error_result": error_result,
1504 "error_reason": error_reason,
1505 "param_reqs": param_reqs,
1506 }
1507 return info_dict
1508
1509 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001510 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1511 error_name = ErrorIf.PoolingOutputShapeNonInteger
1512 param_reqs = {"rank": None, "dtype": None, "shape": None}
1513 error_result = False
1514 error_reason = "Parameters do not yield exact integer output dimensions"
1515
1516 if check:
1517 pad = kwargs["pad"]
1518 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1519
1520 kernel = kwargs["kernel"]
1521 kernel_y, kernel_x = kernel[0], kernel[1]
1522
1523 input_shape = kwargs["input_shape"]
1524 IH, IW = input_shape[1], input_shape[2]
1525
1526 stride = kwargs["stride"]
1527 stride_y, stride_x = stride[0], stride[1]
1528
1529 # calculate remainder of height, width dimensions
1530 if stride_x != 0 and stride_y != 0:
1531 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1532 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1533
1534 # ensure parameters are valid
1535 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1536 if params_valid and (y_remainder != 0 or x_remainder != 0):
1537 error_result = True
1538
1539 info_dict = {
1540 "error_name": error_name,
1541 "error_result": error_result,
1542 "error_reason": error_reason,
1543 "param_reqs": param_reqs,
1544 }
1545 return info_dict
1546
1547 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001548 def checkConvParams(op, weight_shape, stride, pad, dilation):
1549 if op == Op.TRANSPOSE_CONV2D:
1550 pad_ok = (
1551 pad[0] > -weight_shape[1]
1552 and pad[1] > -weight_shape[1]
1553 and pad[2] > -weight_shape[2]
1554 and pad[3] > -weight_shape[2]
1555 )
1556 else:
1557 pad_ok = min(pad) >= 0
1558
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001559 return (
1560 # Check kernel sizes
1561 min(weight_shape[1:-1]) >= 1
1562 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001563 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001564 and (dilation is None or min(dilation) >= 1)
1565 )
1566
1567 @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """ERROR_IF check: conv output spatial dims differ from the computed dims.

        Handles CONV2D/CONV3D/DEPTHWISE_CONV2D (standard formula with
        dilation) and TRANSPOSE_CONV2D (no dilation; out_pad based formula).
        Returns the standard error info dict.
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # Transpose conv takes no dilation attribute
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # Depthwise weights have no leading out-channel dim, so the
            # kernel spatial dims start at index 0 instead of 1
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                # One spatial axis per stride entry; pad holds (before, after)
                # pairs flattened, hence the pad_offset of index * 2
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            # Compare only the spatial dims (drop batch and channel)
            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1628
1629 @staticmethod
1630 def evConvOutputShapeNonInteger(check=False, **kwargs):
1631 error_name = ErrorIf.ConvOutputShapeNonInteger
1632 param_reqs = {"rank": None, "dtype": None, "shape": None}
1633 error_result = False
1634 error_reason = "Parameters do not yield exact integer output dimensions"
1635
1636 if check:
1637 op = kwargs["op"]
1638 pad = kwargs["pad"]
1639 weight_shape = kwargs["weight_shape"]
1640 input_shape = kwargs["input_shape"]
1641 dilation = kwargs["dilation"]
1642 stride = kwargs["stride"]
1643
1644 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1645
1646 # calculate correct height, width dimensions
1647 remainders = []
1648 if min(stride) > 0:
1649 for index in range(len(stride)):
1650 pad_offset = index * 2
1651 remainders.append(
1652 (
1653 input_shape[index + 1]
1654 - 1
1655 + pad[pad_offset]
1656 + pad[pad_offset + 1]
1657 - (weight_shape[index + kernel_offset] - 1)
1658 * dilation[index]
1659 )
1660 % stride[index]
1661 )
1662
1663 # ensure parameters are valid
1664 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001665 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001666 )
1667 if params_valid and max(remainders) > 0:
1668 error_result = True
1669
1670 info_dict = {
1671 "error_name": error_name,
1672 "error_result": error_result,
1673 "error_reason": error_reason,
1674 "param_reqs": param_reqs,
1675 }
1676 return info_dict
1677
1678 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001679 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1680 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1681 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1682 error_result = False
1683 error_reason = (
1684 "Mismatch between output shape provided and expected output shape"
1685 )
1686
1687 if check:
1688 output_shape = kwargs["output_shape"]
1689 input_shape = kwargs["input_shape"]
1690 axis = kwargs["axis"]
1691
1692 dimension_match = True
1693 axis_shift = 0
1694
1695 # Check that rank is correct before trying to check dimensions
1696 if (len(input_shape) - 1) == len(output_shape):
1697 for i in range(len(input_shape)):
1698 if i == axis:
1699 axis_shift = 1
1700 continue
1701 if input_shape[i] != output_shape[i - axis_shift]:
1702 dimension_match = False
1703
1704 if not dimension_match:
1705 error_result = True
1706
1707 info_dict = {
1708 "error_name": error_name,
1709 "error_result": error_result,
1710 "error_reason": error_reason,
1711 "param_reqs": param_reqs,
1712 }
1713 return info_dict
1714
1715 @staticmethod
1716 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1717 error_name = ErrorIf.ArgmaxOutputRankMismatch
1718 param_reqs = {"rank": None, "dtype": None, "shape": None}
1719 error_result = False
1720 error_reason = (
1721 "Mismatch between output shape provided and expected output shape"
1722 )
1723
1724 if check:
1725 output_shape = kwargs["output_shape"]
1726 input_shape = kwargs["input_shape"]
1727 axis = kwargs["axis"]
1728 valid_params = axis >= 0 and axis < len(input_shape)
1729
1730 if valid_params and (len(input_shape) - 1) != len(output_shape):
1731 error_result = True
1732
1733 info_dict = {
1734 "error_name": error_name,
1735 "error_result": error_result,
1736 "error_reason": error_reason,
1737 "param_reqs": param_reqs,
1738 }
1739 return info_dict
1740
1741 @staticmethod
1742 def evKernelSmallerOne(check=False, **kwargs):
1743 error_name = ErrorIf.KernelSmallerOne
1744 param_reqs = {"rank": None, "dtype": None, "shape": None}
1745 error_result = False
1746 error_reason = "At least one kernel dimension is smaller than zero"
1747
1748 if check:
1749 kernel = kwargs["kernel"]
1750 if min(kernel) < 1:
1751 error_result = True
1752
1753 info_dict = {
1754 "error_name": error_name,
1755 "error_result": error_result,
1756 "error_reason": error_reason,
1757 "param_reqs": param_reqs,
1758 }
1759 return info_dict
1760
1761 @staticmethod
1762 def evStrideSmallerOne(check=False, **kwargs):
1763 error_name = ErrorIf.StrideSmallerOne
1764 param_reqs = {"rank": None, "dtype": None, "shape": None}
1765 error_result = False
1766 error_reason = "At least one stride dimension is smaller than zero"
1767
1768 if check:
1769 stride = kwargs["stride"]
1770 if min(stride) < 1:
1771 error_result = True
1772
1773 info_dict = {
1774 "error_name": error_name,
1775 "error_result": error_result,
1776 "error_reason": error_reason,
1777 "param_reqs": param_reqs,
1778 }
1779 return info_dict
1780
1781 @staticmethod
1782 def evDilationSmallerOne(check=False, **kwargs):
1783 error_result = check and min(kwargs["dilation"]) < 1
1784 return {
1785 "error_name": ErrorIf.DilationSmallerOne,
1786 "error_reason": "At least one dilation is smaller than one",
1787 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1788 "error_result": error_result,
1789 }
1790
1791 @staticmethod
1792 def evScaleTrue(check=False, **kwargs):
1793 error_name = ErrorIf.ScaleTrue
1794 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1795 error_result = False
1796 error_reason = "Scale set to true but input type is INT48"
1797
1798 if check:
1799 input_dtype = kwargs["input_dtype"]
1800 scale32 = kwargs["scale32"]
1801 if scale32 and input_dtype == DType.INT48:
1802 error_result = True
1803
1804 info_dict = {
1805 "error_name": error_name,
1806 "error_result": error_result,
1807 "error_reason": error_reason,
1808 "param_reqs": param_reqs,
1809 }
1810 return info_dict
1811
1812 @staticmethod
1813 def evScaleNotTrue(check=False, **kwargs):
1814 error_name = ErrorIf.ScaleNotTrue
1815 param_reqs = {"rank": None, "dtype": None, "shape": None}
1816 error_result = False
1817 error_reason = "Scale set to false but double round set to true"
1818
1819 if check:
1820 scale32 = kwargs["scale32"]
1821 double_round = kwargs["double_round"]
1822 if not scale32 and double_round:
1823 error_result = True
1824
1825 info_dict = {
1826 "error_name": error_name,
1827 "error_result": error_result,
1828 "error_reason": error_reason,
1829 "param_reqs": param_reqs,
1830 }
1831 return info_dict
1832
1833 @staticmethod
1834 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1835 error_name = ErrorIf.TensorSizeInputOutputMismatch
1836 param_reqs = {"rank": None, "dtype": None, "shape": None}
1837 error_result = False
1838 error_reason = "Input tensor size does not match output tensor size"
Jerry Ge264f7fa2023-04-21 22:49:57 +00001839 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001840
1841 if check:
1842 input_shape = kwargs["input_shape"]
1843 output_shape = kwargs["output_shape"]
Jerry Ge264f7fa2023-04-21 22:49:57 +00001844 shape_inferencing = False
1845 if -1 in output_shape and op["op"] == Op.RESHAPE:
1846 shape_inferencing = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001847 input_size = np.prod(input_shape)
1848 output_size = np.prod(output_shape)
Jerry Ge264f7fa2023-04-21 22:49:57 +00001849 if input_size != output_size and not shape_inferencing:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001850 error_result = True
1851
1852 info_dict = {
1853 "error_name": error_name,
1854 "error_result": error_result,
1855 "error_reason": error_reason,
1856 "param_reqs": param_reqs,
1857 }
1858 return info_dict
1859
1860 @staticmethod
1861 def evStartSmallerZero(check=False, **kwargs):
1862 error_name = ErrorIf.StartSmallerZero
1863 param_reqs = {"rank": None, "dtype": None, "shape": None}
1864 error_result = False
1865 error_reason = "Starting point smaller than zero"
1866
1867 if check:
1868 input_shape = kwargs["input_shape"]
1869 start = kwargs["start"]
1870 rank = len(input_shape)
1871 if len(start) == rank:
1872 for index in range(rank):
1873 if start[index] < 0:
1874 error_result = True
1875
1876 info_dict = {
1877 "error_name": error_name,
1878 "error_result": error_result,
1879 "error_reason": error_reason,
1880 "param_reqs": param_reqs,
1881 }
1882 return info_dict
1883
1884 @staticmethod
1885 def evSizeSmallerEqualZero(check=False, **kwargs):
1886 error_name = ErrorIf.SizeSmallerEqualZero
1887 param_reqs = {"rank": None, "dtype": None, "shape": None}
1888 error_result = False
1889 error_reason = "Size smaller than or equal to zero"
1890
1891 if check:
1892 input_shape = kwargs["input_shape"]
1893 size = kwargs["size"]
1894 rank = len(input_shape)
1895 if len(size) == rank:
1896 for index in range(rank):
1897 if size[index] <= 0:
1898 error_result = True
1899
1900 info_dict = {
1901 "error_name": error_name,
1902 "error_result": error_result,
1903 "error_reason": error_reason,
1904 "param_reqs": param_reqs,
1905 }
1906 return info_dict
1907
1908 @staticmethod
1909 def evStartSizeOutsideBounds(check=False, **kwargs):
1910 error_name = ErrorIf.StartSizeOutsideBounds
1911 param_reqs = {"rank": None, "dtype": None, "shape": None}
1912 error_result = False
1913 error_reason = "starting point plus size larger than input dimension"
1914
1915 if check:
1916 input_shape = kwargs["input_shape"]
1917 start = kwargs["start"]
1918 size = kwargs["size"]
1919 rank = len(input_shape)
1920 if len(start) == rank and len(size) == rank:
1921 for index in range(rank):
1922 if start[index] + size[index] > input_shape[index]:
1923 error_result = True
1924
1925 info_dict = {
1926 "error_name": error_name,
1927 "error_result": error_result,
1928 "error_reason": error_reason,
1929 "param_reqs": param_reqs,
1930 }
1931 return info_dict
1932
1933 @staticmethod
1934 def evSizeOutputShapeMismatch(check=False, **kwargs):
1935 error_name = ErrorIf.SizeOutputShapeMismatch
1936 param_reqs = {"rank": None, "dtype": None, "shape": None}
1937 error_result = False
1938 error_reason = "Size does not match output dimension"
1939
1940 if check:
1941 input_shape = kwargs["input_shape"]
1942 output_shape = kwargs["output_shape"]
1943 size = kwargs["size"]
Luke Huttona4e48ca2023-02-22 11:53:48 +00001944
1945 if len(input_shape) == len(output_shape):
1946 rank = len(input_shape)
1947 if len(size) == rank:
1948 for index in range(rank):
1949 if size[index] != output_shape[index]:
1950 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001951
1952 info_dict = {
1953 "error_name": error_name,
1954 "error_result": error_result,
1955 "error_reason": error_reason,
1956 "param_reqs": param_reqs,
1957 }
1958 return info_dict
1959
1960 @staticmethod
1961 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1962 error_name = ErrorIf.InputSizeStartLengthMismatch
1963 param_reqs = {"rank": None, "dtype": None, "shape": None}
1964 error_result = False
1965 error_reason = "rank of input not equal to length of start or size"
1966
1967 if check:
1968 input_shape = kwargs["input_shape"]
1969 start = kwargs["start"]
1970 size = kwargs["size"]
1971 rank = len(input_shape)
1972 if rank != len(start) or rank != len(size):
1973 error_result = True
1974
1975 info_dict = {
1976 "error_name": error_name,
1977 "error_result": error_result,
1978 "error_reason": error_reason,
1979 "param_reqs": param_reqs,
1980 }
1981 return info_dict
1982
1983 @staticmethod
1984 def evIndexOutsideBounds(check=False, **kwargs):
1985 error_name = ErrorIf.IndexOutsideBounds
1986 param_reqs = {"rank": None, "dtype": None, "shape": None}
1987 error_result = False
1988 error_reason = "Index outside of allowed bounds"
1989
1990 if check:
1991 input_shape = kwargs["input_shape"]
1992 perms = kwargs["perms"]
1993 rank = len(input_shape)
1994
1995 for index in perms:
1996 if index < 0 or index > rank:
1997 error_result = True
1998
1999 info_dict = {
2000 "error_name": error_name,
2001 "error_result": error_result,
2002 "error_reason": error_reason,
2003 "param_reqs": param_reqs,
2004 }
2005 return info_dict
2006
2007 @staticmethod
2008 def evIndexUsedTwice(check=False, **kwargs):
2009 error_name = ErrorIf.IndexUsedTwice
2010 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2011 error_result = False
2012 error_reason = "Index used multiple times"
2013
2014 if check:
2015 perms = kwargs["perms"]
2016
2017 unique_indices = []
2018 for index in perms:
2019 if index in unique_indices:
2020 error_result = True
2021 else:
2022 unique_indices.append(index)
2023
2024 info_dict = {
2025 "error_name": error_name,
2026 "error_result": error_result,
2027 "error_reason": error_reason,
2028 "param_reqs": param_reqs,
2029 }
2030 return info_dict
2031
2032 @staticmethod
2033 def evMaxSmallerMin(check=False, **kwargs):
2034 error_name = ErrorIf.MaxSmallerMin
2035 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2036 error_result = False
2037 error_reason = "Max value smaller than min value"
2038
2039 if check:
2040 max_val = kwargs["max_val"]
2041 min_val = kwargs["min_val"]
2042 if max_val < min_val:
2043 error_result = True
2044
2045 info_dict = {
2046 "error_name": error_name,
2047 "error_result": error_result,
2048 "error_reason": error_reason,
2049 "param_reqs": param_reqs,
2050 }
2051 return info_dict
2052
2053 @staticmethod
2054 def evConcatInputRankMismatch(check=False, **kwargs):
2055 error_name = ErrorIf.ConcatInputRankMismatch
2056 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2057 error_result = False
2058 error_reason = "Input ranks are not identical"
2059
2060 if check:
2061 inputs = kwargs["inputs"]
2062 input_shape = kwargs["input_shape"]
2063 for input in inputs:
2064 if len(input.shape) != len(input_shape):
2065 error_result = True
2066
2067 info_dict = {
2068 "error_name": error_name,
2069 "error_result": error_result,
2070 "error_reason": error_reason,
2071 "param_reqs": param_reqs,
2072 }
2073 return info_dict
2074
2075 @staticmethod
2076 def evConcatInputDimMismatch(check=False, **kwargs):
2077 error_name = ErrorIf.ConcatInputDimMismatch
2078 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2079 error_result = False
2080 error_reason = "Input dimensions differ on too many axes"
2081
2082 if check:
2083 inputs = kwargs["inputs"]
2084 input_shape = kwargs["input_shape"]
2085 axis = kwargs["axis"]
2086
2087 # Ensure rank is valid before checking dims.
2088 valid_rank = True
2089 for input in inputs:
2090 if len(input.shape) != len(input_shape):
2091 valid_rank = False
2092
2093 if valid_rank:
2094 for input in inputs:
2095 for i, dim in enumerate(input.shape):
2096 if dim != input_shape[i] and axis != i:
2097 error_result = True
2098
2099 info_dict = {
2100 "error_name": error_name,
2101 "error_result": error_result,
2102 "error_reason": error_reason,
2103 "param_reqs": param_reqs,
2104 }
2105 return info_dict
2106
2107 @staticmethod
2108 def evConcatShapeSumMismatch(check=False, **kwargs):
2109 error_name = ErrorIf.ConcatShapeSumMismatch
2110 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2111 error_result = False
2112 error_reason = "Sum of dimensions on axis not equal to output dimension"
2113
2114 if check:
2115 inputs = kwargs["inputs"]
2116 input_shape = kwargs["input_shape"]
2117 output_shape = kwargs["output_shape"]
2118 axis = kwargs["axis"]
2119
2120 # Ensure rank is valid before checking dims.
2121 valid_params = True
2122 for input in inputs:
2123 if len(input.shape) != len(input_shape):
2124 valid_params = False
2125 if axis < 0 or axis > len(input_shape):
2126 valid_params = False
2127
2128 if valid_params:
2129 axis_dim_sum = 0
2130 for input in inputs:
2131 axis_dim_sum += input.shape[axis]
2132
2133 if axis_dim_sum != output_shape[axis]:
2134 error_result = True
2135
2136 info_dict = {
2137 "error_name": error_name,
2138 "error_result": error_result,
2139 "error_reason": error_reason,
2140 "param_reqs": param_reqs,
2141 }
2142 return info_dict
2143
2144 @staticmethod
2145 def evInputListThenGraphMismatch(check=False, **kwargs):
2146 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2147 param_reqs = {"rank": None, "dtype": None, "shape": None}
2148 error_result = False
2149 error_reason = "Input list shape does not match then-graph shape"
2150
2151 if check:
2152 a = kwargs["a"]
2153 b = kwargs["b"]
2154 basicBlocks = kwargs["basicBlocks"]
2155 then_block = basicBlocks[1]
2156 then_inputs = then_block.inputs
2157 then_tens = then_block.tensors
2158 if (a.shape != then_tens[then_inputs[0]].shape) or (
2159 b.shape != then_tens[then_inputs[1]].shape
2160 ):
2161 error_result = True
2162
2163 info_dict = {
2164 "error_name": error_name,
2165 "error_result": error_result,
2166 "error_reason": error_reason,
2167 "param_reqs": param_reqs,
2168 }
2169 return info_dict
2170
2171 @staticmethod
2172 def evInputListElseGraphMismatch(check=False, **kwargs):
2173 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2174 param_reqs = {"rank": None, "dtype": None, "shape": None}
2175 error_result = False
2176 error_reason = "Input list shape does not match else-graph shape"
2177
2178 if check:
2179 a = kwargs["a"]
2180 b = kwargs["b"]
2181 basicBlocks = kwargs["basicBlocks"]
2182 else_block = basicBlocks[2]
2183 else_inputs = else_block.inputs
2184 else_tens = else_block.tensors
2185 if (a.shape != else_tens[else_inputs[0]].shape) or (
2186 b.shape != else_tens[else_inputs[1]].shape
2187 ):
2188 error_result = True
2189
2190 info_dict = {
2191 "error_name": error_name,
2192 "error_result": error_result,
2193 "error_reason": error_reason,
2194 "param_reqs": param_reqs,
2195 }
2196 return info_dict
2197
2198 @staticmethod
2199 def evOutputListThenGraphMismatch(check=False, **kwargs):
2200 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2201 param_reqs = {"rank": None, "dtype": None, "shape": None}
2202 error_result = False
2203 error_reason = "Output list shape does not match then-graph shape"
2204
2205 if check:
2206 basicBlocks = kwargs["basicBlocks"]
2207 cond_block = basicBlocks[0]
2208 cond_outputs = cond_block.outputs
2209 cond_tens = cond_block.tensors
2210 then_block = basicBlocks[1]
2211 then_outputs = then_block.outputs
2212 then_tens = then_block.tensors
2213 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2214 error_result = True
2215
2216 info_dict = {
2217 "error_name": error_name,
2218 "error_result": error_result,
2219 "error_reason": error_reason,
2220 "param_reqs": param_reqs,
2221 }
2222 return info_dict
2223
2224 @staticmethod
2225 def evOutputListElseGraphMismatch(check=False, **kwargs):
2226 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2227 param_reqs = {"rank": None, "dtype": None, "shape": None}
2228 error_result = False
2229 error_reason = "Output list shape does not match else-graph shape"
2230
2231 if check:
2232 basicBlocks = kwargs["basicBlocks"]
2233 cond_block = basicBlocks[0]
2234 cond_outputs = cond_block.outputs
2235 cond_tens = cond_block.tensors
2236 else_block = basicBlocks[2]
2237 else_outputs = else_block.outputs
2238 else_tens = else_block.tensors
2239 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2240 error_result = True
2241
2242 info_dict = {
2243 "error_name": error_name,
2244 "error_result": error_result,
2245 "error_reason": error_reason,
2246 "param_reqs": param_reqs,
2247 }
2248 return info_dict
2249
2250 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002251 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2252 error_name = ErrorIf.CondIfCondNotMatchingBool
2253 param_reqs = {"rank": None, "dtype": None, "shape": None}
2254 error_result = False
2255 error_reason = "Conditional tensor does not match bool type"
2256
2257 if check:
2258 cond = kwargs["cond"]
2259 if cond.dtype != DType.BOOL:
2260 error_result = True
2261
2262 info_dict = {
2263 "error_name": error_name,
2264 "error_result": error_result,
2265 "error_reason": error_reason,
2266 "param_reqs": param_reqs,
2267 }
2268 return info_dict
2269
2270 @staticmethod
2271 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2272 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2273 param_reqs = {"rank": None, "dtype": None, "shape": None}
2274 error_result = False
2275 error_reason = "Conditional tensor is not equal to a size of one"
2276
2277 if check:
2278 cond = kwargs["cond"]
2279 # Size of 1 is equivalent to rank 0
2280 if len(cond.shape) != 0:
2281 error_result = True
2282
2283 info_dict = {
2284 "error_name": error_name,
2285 "error_result": error_result,
2286 "error_reason": error_reason,
2287 "param_reqs": param_reqs,
2288 }
2289 return info_dict
2290
2291 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002292 def evInputListOutputListMismatch(check=False, **kwargs):
2293 error_name = ErrorIf.InputListOutputListMismatch
2294 param_reqs = {"rank": None, "dtype": None, "shape": None}
2295 error_result = False
2296 error_reason = "Input list does not match output list"
2297
2298 if check:
2299 basicBlocks = kwargs["basicBlocks"]
2300 while_block = basicBlocks[0]
2301 while_inputs = while_block.inputs
2302 while_outputs = while_block.outputs
2303 while_tens = while_block.tensors
2304 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2305 error_result = True
2306
2307 info_dict = {
2308 "error_name": error_name,
2309 "error_result": error_result,
2310 "error_reason": error_reason,
2311 "param_reqs": param_reqs,
2312 }
2313 return info_dict
2314
2315 @staticmethod
2316 def evInputListCondGraphMismatch(check=False, **kwargs):
2317 error_name = ErrorIf.InputListCondGraphMismatch
2318 param_reqs = {"rank": None, "dtype": None, "shape": None}
2319 error_result = False
2320 error_reason = "Input list does not match cond graph"
2321
2322 if check:
2323 basicBlocks = kwargs["basicBlocks"]
2324 while_block = basicBlocks[0]
2325 while_inputs = while_block.inputs
2326 while_tens = while_block.tensors
2327 cond_block = basicBlocks[1]
2328 cond_inputs = cond_block.inputs
2329 cond_tens = cond_block.tensors
2330 if (
2331 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2332 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2333 error_result = True
2334
2335 info_dict = {
2336 "error_name": error_name,
2337 "error_result": error_result,
2338 "error_reason": error_reason,
2339 "param_reqs": param_reqs,
2340 }
2341 return info_dict
2342
2343 @staticmethod
2344 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2345 error_name = ErrorIf.InputListBodyGraphInputMismatch
2346 param_reqs = {"rank": None, "dtype": None, "shape": None}
2347 error_result = False
2348 error_reason = "Input list does not match body graph input"
2349
2350 if check:
2351 basicBlocks = kwargs["basicBlocks"]
2352 while_block = basicBlocks[0]
2353 while_inputs = while_block.inputs
2354 while_tens = while_block.tensors
2355 body_block = basicBlocks[2]
2356 body_outputs = body_block.inputs
2357 body_tens = body_block.tensors
2358 if (
2359 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2360 ) or (
2361 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2362 ):
2363 error_result = True
2364
2365 info_dict = {
2366 "error_name": error_name,
2367 "error_result": error_result,
2368 "error_reason": error_reason,
2369 "param_reqs": param_reqs,
2370 }
2371 return info_dict
2372
2373 @staticmethod
2374 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2375 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2376 param_reqs = {"rank": None, "dtype": None, "shape": None}
2377 error_result = False
2378 error_reason = "Input list does not match body graph output"
2379
2380 if check:
2381 basicBlocks = kwargs["basicBlocks"]
2382 while_block = basicBlocks[0]
2383 while_inputs = while_block.inputs
2384 while_tens = while_block.tensors
2385 body_block = basicBlocks[2]
2386 body_outputs = body_block.outputs
2387 body_tens = body_block.tensors
2388 if (
2389 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2390 ) or (
2391 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2392 ):
2393 error_result = True
2394 info_dict = {
2395 "error_name": error_name,
2396 "error_result": error_result,
2397 "error_reason": error_reason,
2398 "param_reqs": param_reqs,
2399 }
2400 return info_dict
2401
2402 @staticmethod
2403 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2404 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2405 param_reqs = {"rank": None, "dtype": None, "shape": None}
2406 error_result = False
2407 error_reason = "Cond graph output is not a match list of booleans"
2408
2409 if check:
2410 basicBlocks = kwargs["basicBlocks"]
2411 cond_block = basicBlocks[1]
2412 cond_outputs = cond_block.outputs
2413 cond_tens = cond_block.tensors
2414 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2415 error_result = True
2416
2417 info_dict = {
2418 "error_name": error_name,
2419 "error_result": error_result,
2420 "error_reason": error_reason,
2421 "param_reqs": param_reqs,
2422 }
2423 return info_dict
2424
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002425 @staticmethod
2426 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2427 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2428 param_reqs = {"rank": None, "dtype": None, "shape": None}
2429 error_result = False
2430 error_reason = "Cond graph output is not a shape of size one"
2431
2432 if check:
2433 basicBlocks = kwargs["basicBlocks"]
2434 cond_block = basicBlocks[1]
2435 cond_outputs = cond_block.outputs
2436 cond_tens = cond_block.tensors
2437 # Size of 1 is equivalent to rank 0
2438 if len(cond_tens[cond_outputs[0]].shape) != 0:
2439 error_result = True
2440
2441 info_dict = {
2442 "error_name": error_name,
2443 "error_result": error_result,
2444 "error_reason": error_reason,
2445 "param_reqs": param_reqs,
2446 }
2447 return info_dict
2448
Luke Hutton261b7b62023-01-10 14:50:31 +00002449 @staticmethod
2450 def evKernelNotPowerOfTwo(check=False, **kwargs):
2451 error_name = ErrorIf.KernelNotPowerOfTwo
2452 param_reqs = {"rank": None, "dtype": None, "shape": None}
2453 error_result = False
2454 error_reason = "kernel height and/or width not a power of two"
2455
2456 def is_power_of_two(x):
2457 return math.log(x, 2).is_integer()
2458
2459 if check:
2460 shape = kwargs["input_shape"]
2461 if len(shape) == 3:
2462 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2463 error_result = not valid_kernel
2464
2465 info_dict = {
2466 "error_name": error_name,
2467 "error_result": error_result,
2468 "error_reason": error_reason,
2469 "param_reqs": param_reqs,
2470 }
2471 return info_dict
2472
Luke Hutton57287132023-02-06 14:54:18 +00002473 @staticmethod
2474 def evFFTInputShapeMismatch(check=False, **kwargs):
2475 error_name = ErrorIf.FFTInputShapeMismatch
2476 param_reqs = {"rank": None, "dtype": None, "shape": None}
2477 error_result = False
2478 error_reason = "Mismatch between real and imaginary input shapes"
2479
2480 if check:
2481 input1 = kwargs["input1"]
2482 input2 = kwargs["input2"]
2483
2484 if input1.shape != input2.shape:
2485 error_result = True
2486
2487 info_dict = {
2488 "error_name": error_name,
2489 "error_result": error_result,
2490 "error_reason": error_reason,
2491 "param_reqs": param_reqs,
2492 }
2493 return info_dict
2494
2495 @staticmethod
2496 def evFFTOutputShapeMismatch(check=False, **kwargs):
2497 error_name = ErrorIf.FFTOutputShapeMismatch
2498 param_reqs = {"rank": None, "dtype": None, "shape": None}
2499 error_result = False
2500 error_reason = (
2501 "Mismatch between provided and expected output kernel (H, W) shape"
2502 )
2503
2504 if check:
2505 op = kwargs["op"]
2506 input_shape = kwargs["input_shape"]
2507
2508 if len(input_shape) == 3:
2509 output_shapes = kwargs["output_shape"]
2510
2511 # Ignoring batch size (N) from input shape
2512 expected_shape = input_shape[1:]
2513 if op["op"] == Op.RFFT2D:
2514 expected_shape[1] = expected_shape[1] // 2 + 1
2515
2516 # Ignoring batch size (N) from output shapes
2517 output_shape_0 = output_shapes[0][1:]
2518 output_shape_1 = output_shapes[1][1:]
2519 # Ensure sure the kernel sizes (H, W) of both outputs match the expected
2520 if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
2521 error_result = True
2522
2523 info_dict = {
2524 "error_name": error_name,
2525 "error_result": error_result,
2526 "error_reason": error_reason,
2527 "param_reqs": param_reqs,
2528 }
2529 return info_dict
2530
Jerry Ge264f7fa2023-04-21 22:49:57 +00002531 @staticmethod
Jerry Ge135c9552023-05-23 20:59:32 +00002532 def calculateBroadcastShape(input_shape_a, input_shape_b):
2533 if input_shape_a is not None and input_shape_b is not None:
2534 calculated_shape = input_shape_a.copy()
2535 for idx in range(len(calculated_shape)):
2536 if calculated_shape[idx] == 1:
2537 calculated_shape[idx] = input_shape_b[idx]
2538 elif (
2539 input_shape_b[idx] != 1
2540 and input_shape_b[idx] != calculated_shape[idx]
2541 ):
2542 return None
2543 return calculated_shape
2544 else:
2545 return None
2546
2547 @staticmethod
2548 def evBroadcastShapesMismatch(check=False, **kwargs):
2549 error_name = ErrorIf.BroadcastShapesMismatch
2550 param_reqs = {"rank": None, "dtype": None, "shape": None}
2551 error_result = False
2552 error_reason = "Broadcast shape calculating failed"
2553
2554 if check:
2555 input_shape_a = kwargs["input1"].shape
2556 input_shape_b = kwargs["input2"].shape
2557 input_shape_c = (
2558 kwargs["input3"].shape if "input3" in kwargs else input_shape_b
2559 )
2560
2561 if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
2562 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
2563 input_shape_c,
2564 TosaErrorValidator.calculateBroadcastShape(
2565 input_shape_a, input_shape_b
2566 ),
2567 )
2568 error_result = calculated_shape is None
2569
2570 info_dict = {
2571 "error_name": error_name,
2572 "error_result": error_result,
2573 "error_reason": error_reason,
2574 "param_reqs": param_reqs,
2575 }
2576 return info_dict
2577
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002578
2579class TosaInvalidValidator:
2580 @staticmethod
2581 def ivWrongDataTypeOrModeResize(**kwargs):
2582 input_dtype = kwargs["input_dtype"]
2583 args = kwargs["args"]
2584 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002585 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002586
2587 if mode == ResizeMode.BILINEAR:
2588 # Invalid output data type / Invalid input datatype
2589 return (
2590 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002591 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002592 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002593 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002594 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002595 )
2596 elif mode == ResizeMode.NEAREST:
2597 # Invalid output data type / Invalid input datatype
2598 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002599 input_dtype
2600 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002601 )
2602 else:
2603 # Invalid resize mode
2604 return True
2605
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when the op's computed output H/W would be invalid.

        Supports the pool2d, transpose_conv2d and conv2d/conv3d families:
        pooling and convolution are invalid when any computed spatial
        dimension is < 1; transpose_conv2d is invalid when the supplied
        output shape does not match the computed one.
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        # First shape is the data tensor; assumed NHWC/NDHWC layout with
        # spatial dims starting at index 1 — TODO confirm against callers.
        input_shape = inputShapes[0]

        args = kwargs["args"]

        if isinstance(args, dict):
            args_dict = args
        else:
            # Create args_dict from list elements
            # TODO - Remove this once all NWHC operators agFunctions have been
            # converted to args_dict output

            # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
            stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
            args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
            # Alias different info for each op
            # (kernel/out_shape/dilation all map to the same positional slot;
            # only the key relevant to the current op is read below.)
            args_dict["kernel"] = args[pad_idx + 1]
            args_dict["out_shape"] = args[pad_idx + 1]
            args_dict["dilation"] = args[pad_idx + 1]

        # Common info for all ops
        strides = args_dict["stride"]
        padding = args_dict["pad"]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args_dict["kernel"]
            # Standard pooling output-size formula per spatial dimension;
            # padding is ordered [top, bottom, left, right].
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args_dict["out_shape"]
            # Second shape is the filter; spatial kernel dims are its middle axes.
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension."""
                return (in_size - 1) * stride + kernel_size + in_pad + out_pad

            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                padding[0],
                padding[1],
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                padding[2],
                padding[3],
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False
            # output shape does not match the expected shape
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args_dict["dilation"]
            filter_shape = inputShapes[1]
            # depthwise filters carry their spatial dims first; regular conv
            # filters carry them in the middle axes.
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                # Each spatial dim uses a (before, after) padding pair.
                pad_offset = i * 2
                # Standard dilated-convolution output-size formula.
                dim = (
                    input_shape[i + 1]
                    - 1
                    + padding[pad_offset]
                    + padding[pad_offset + 1]
                    - (kernel_shape[i] - 1) * dilations[i]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"
2700
2701 @staticmethod
2702 def ivNonPositiveOutputShape(**kwargs):
2703 args = kwargs["args"]
Jeremy Johnson95a67102024-01-10 14:16:39 +00002704 output_shape = args["out_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002705 if output_shape[1] <= 0 or output_shape[2] <= 0:
2706 # Negative output shape
2707 return True
2708 return False