# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"
    WrongAccumulatorType = "WrongAccumulatorType"


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType
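
    # The resize `scale` list follows the TOSA RESIZE operand layout
    # [scale_y_n, scale_y_d, scale_x_n, scale_x_d], which is why the
    # numerator mutations above pick indices 0/2 and the denominator
    # mutations pick indices 1/3.  Illustrative (hypothetical) values: with
    # scale = [4, 2, 4, 2] and error_name == ErrorIf.ScaleNLargerMax, one
    # numerator is pushed past the (1 << 11) == 2048 limit, e.g. scale
    # becomes [2049, 2, 4, 2].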

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False
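
    # Illustrative results (dtype pairs chosen for this sketch):
    #   eiRescaleWrongOutputType(DType.INT8, DType.INT48)   -> True (invalid)
    #   eiRescaleWrongOutputType(DType.UINT16, DType.INT16) -> False (valid)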

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape
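
    # Worked example (hypothetical shape): eiRestrictDimensions([32, 32, 32, 32])
    # keeps the dimensions (none exceeds max_dim), but 32**4 = 1048576 exceeds
    # max_items, so every dimension is stepped down together until
    # 17**4 = 83521 <= 100000, giving [17, 17, 17, 17].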

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])

            # Removing the only dimension of a 1-D tensor would leave an
            # empty start/size list, so always append instead.
            if len(start) == 1 or len(size) == 1:
                remove = False

            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        if validator_fcns is None:
            # Nothing to do
            return True
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # An error is expected if and only if error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result
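
    # Outcome table implied by `expected_result = error_result ==
    # (error_name == validator_name)` above:
    #   names match,  validator flags error -> expected; return code set to 2
    #   names match,  no error flagged      -> "Missed ERROR_IF" reported
    #   names differ, validator flags error -> "Unexpected ERROR_IF" reported
    #   names differ, no error flagged      -> expected; nothing to do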

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into the required list-of-types form
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
                if output_dtype != DType.SHAPE:
                    error_result = True

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
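
    # Example: assuming MAX_RESIZE_DIMENSION is 16384 (its value in
    # tosa_utils at the time of writing), the param_reqs shape
    # [1, 16584, 5, 1] above trips this check because 16584 >= 16384.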

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = (
                    output.shape
                )  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )
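
    # Worked example (illustrative values): scale = [4, 2, 4, 2],
    # offset = [0, 0], border = [0, 0] is valid because 0 < 4 <= 2048,
    # 2 < 16 * 4, -4 <= 0 < 64 (offset bounds) and -64 <= 0 < 4
    # (border bounds) all hold.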

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
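
    # Worked example (illustrative values): input_shape = [1, 8, 8, 1],
    # scale = [4, 2, 4, 2], offset = [0, 0], border = [0, 0] gives
    # output_y = output_x = ((8 - 1) * 4 - 0 + 0) // 2 + 1 = 15, so any
    # output_shape other than [N, 15, 15, C] is flagged as a mismatch.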

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
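
    # Worked example (illustrative values): input_shape = [1, 8, 8, 1],
    # scale = [3, 2, 3, 2], offset = [0, 0], border = [0, 0] gives
    # remainder_y = ((8 - 1) * 3) % 2 = 1; the exact output height
    # ((8 - 1) * 3) / 2 + 1 = 11.5 is not an integer, so the error fires.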

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            op = kwargs["op"]
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
                output_shape = kwargs["result_tensors"][0].shape
                if input1_shape != output_shape:
                    error_result = True

            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
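
    # calculateBroadcastShape (defined further down this module) returns the
    # broadcast result shape, or None for incompatible inputs.  Illustrative
    # examples: shapes [1, 3] and [2, 1] broadcast to [2, 3], while [2, 3]
    # and [3, 2] are incompatible and yield None.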

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )
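
    # Illustrative checks: kernel = [2, 2], stride = [2, 2],
    # pad = [0, 0, 0, 0] is valid, whereas pad = [2, 0, 0, 0] is not
    # because pad[0] >= kernel[0].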

    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1514
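    # Worked example (assumed values): IH = IW = 8, kernel 3x3, stride 2 and
    # no padding give y_correct = x_correct = ((8 - 3) // 2) + 1 = 3, so only
    # an output of [N, 3, 3, C] avoids this ERROR_IF condition.
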
    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

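    # With the same assumed values but IH = IW = 9, (9 - 3) % 2 == 0 is an
    # exact fit, whereas IH = IW = 10 leaves a remainder of 1 and is flagged
    # as a non-integer output dimension.
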
    @staticmethod
    def checkConvParams(op, weight_shape, stride, pad, dilation):
        if op == Op.TRANSPOSE_CONV2D:
            pad_ok = (
                pad[0] > -weight_shape[1]
                and pad[1] > -weight_shape[1]
                and pad[2] > -weight_shape[2]
                and pad[3] > -weight_shape[2]
            )
        else:
            pad_ok = min(pad) >= 0

        return (
            # Check kernel sizes
            min(weight_shape[1:-1]) >= 1
            and min(stride) >= 1
            and pad_ok
            and (dilation is None or min(dilation) >= 1)
        )

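    # Note on the helper above: TRANSPOSE_CONV2D permits negative padding,
    # but only down to one less than the kernel size. For example (assumed
    # shape), a weight_shape of [O, 3, 3, I] requires every pad value to be
    # greater than -3 for the parameters to count as valid.
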
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

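    # Worked example for the formulas above (values assumed): input H = 8,
    # kernel height 3, stride 2, dilation 1 and zero padding give
    # (8 - 1 + 0 + 0 - (3 - 1) * 1) // 2 + 1 = 3, while TRANSPOSE_CONV2D with
    # the same input, stride and kernel expands to (8 - 1) * 2 + 0 + 0 + 3 = 17.
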
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate remainder of height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            dimension_match = True
            axis_shift = 0

            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

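    # For instance (assumed shapes): ARGMAX over axis=1 of an input shaped
    # [2, 5, 7] must produce an output shaped [2, 7]; the reduced axis is
    # removed and every other dimension is preserved in order.
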
    @staticmethod
    def evArgmaxOutputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputRankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]
            valid_params = axis >= 0 and axis < len(input_shape)

            if valid_params and (len(input_shape) - 1) != len(output_shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.KernelSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one kernel dimension is smaller than one"

        if check:
            kernel = kwargs["kernel"]
            if min(kernel) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStrideSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one stride dimension is smaller than one"

        if check:
            stride = kwargs["stride"]
            if min(stride) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDilationSmallerOne(check=False, **kwargs):
        error_result = check and min(kwargs["dilation"]) < 1
        return {
            "error_name": ErrorIf.DilationSmallerOne,
            "error_reason": "At least one dilation is smaller than one",
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
            "error_result": error_result,
        }

    @staticmethod
    def evScaleTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleTrue
        param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
        error_result = False
        error_reason = "Scale set to true but input type is INT48"

        if check:
            input_dtype = kwargs["input_dtype"]
            scale32 = kwargs["scale32"]
            if scale32 and input_dtype == DType.INT48:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNotTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleNotTrue
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale set to false but double round set to true"

        if check:
            scale32 = kwargs["scale32"]
            double_round = kwargs["double_round"]
            if not scale32 and double_round:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

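    # The two RESCALE checks above encode complementary rules: scale32 must
    # be false when the input is INT48, and double rounding is only permitted
    # when scale32 is true.
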
    @staticmethod
    def evTensorSizeInputOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.TensorSizeInputOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input tensor size does not match output tensor size"

        if check:
            # Only read kwargs under check, like the other validators, so the
            # function is safe to call without them
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            shape_inferencing = False
            if -1 in output_shape and op["op"] == Op.RESHAPE:
                shape_inferencing = True
            input_size = np.prod(input_shape)
            output_size = np.prod(output_shape)
            if input_size != output_size and not shape_inferencing:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

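    # RESHAPE may use a single -1 in the new shape, which is inferred from
    # the input size; e.g. reshaping [2, 3, 4] to [6, -1] is size-consistent
    # and is deliberately skipped by the check above.
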
    @staticmethod
    def evStartSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.StartSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point smaller than zero"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            rank = len(input_shape)
            if len(start) == rank:
                for index in range(rank):
                    if start[index] < 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.SizeSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size smaller than or equal to zero"

        if check:
            input_shape = kwargs["input_shape"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(size) == rank:
                for index in range(rank):
                    if size[index] <= 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSizeOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.StartSizeOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point plus size larger than input dimension"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(start) == rank and len(size) == rank:
                for index in range(rank):
                    if start[index] + size[index] > input_shape[index]:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

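    # Example for the SLICE bounds rule (values assumed): with
    # input_shape=[10], start=[4] and size=[6] the slice ends exactly at the
    # boundary (4 + 6 == 10) and is legal, whereas size=[7] overruns and is
    # flagged.
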
    @staticmethod
    def evSizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.SizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size does not match output dimension"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            size = kwargs["size"]

            if len(input_shape) == len(output_shape):
                rank = len(input_shape)
                if len(size) == rank:
                    for index in range(rank):
                        if size[index] != output_shape[index]:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputSizeStartLengthMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputSizeStartLengthMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank of input not equal to length of start or size"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if rank != len(start) or rank != len(size):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.IndexOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index outside of allowed bounds"

        if check:
            input_shape = kwargs["input_shape"]
            perms = kwargs["perms"]
            rank = len(input_shape)

            for index in perms:
                # Valid permutation values are 0 .. rank - 1
                if index < 0 or index >= rank:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexUsedTwice(check=False, **kwargs):
        error_name = ErrorIf.IndexUsedTwice
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index used multiple times"

        if check:
            perms = kwargs["perms"]

            unique_indices = []
            for index in perms:
                if index in unique_indices:
                    error_result = True
                else:
                    unique_indices.append(index)

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.MaxSmallerMin
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Max value smaller than min value"

        if check:
            max_val = kwargs["max_val"]
            min_val = kwargs["min_val"]
            if max_val < min_val:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputRankMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input ranks are not identical"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputDimMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputDimMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions differ on too many axes"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_rank = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_rank = False

            if valid_rank:
                for input in inputs:
                    for i, dim in enumerate(input.shape):
                        if dim != input_shape[i] and axis != i:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatShapeSumMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatShapeSumMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Sum of dimensions on axis not equal to output dimension"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            axis = kwargs["axis"]

            # Ensure rank and axis are valid before checking dims.
            valid_params = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_params = False
            # A valid axis must index into the shape (0 .. rank - 1)
            if axis < 0 or axis >= len(input_shape):
                valid_params = False

            if valid_params:
                axis_dim_sum = 0
                for input in inputs:
                    axis_dim_sum += input.shape[axis]

                if axis_dim_sum != output_shape[axis]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

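    # Example (assumed shapes): concatenating [2, 3] and [2, 5] on axis=1
    # requires an output shape of [2, 8]; the axis dimensions must sum
    # exactly to the output dimension.
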
    @staticmethod
    def evInputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match then-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            then_block = basicBlocks[1]
            then_inputs = then_block.inputs
            then_tens = then_block.tensors
            if (a.shape != then_tens[then_inputs[0]].shape) or (
                b.shape != then_tens[then_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match else-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            else_block = basicBlocks[2]
            else_inputs = else_block.inputs
            else_tens = else_block.tensors
            if (a.shape != else_tens[else_inputs[0]].shape) or (
                b.shape != else_tens[else_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match then-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            then_block = basicBlocks[1]
            then_outputs = then_block.outputs
            then_tens = then_block.tensors
            if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match else-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            else_block = basicBlocks[2]
            else_outputs = else_block.outputs
            else_tens = else_block.tensors
            if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not match bool type"

        if check:
            cond = kwargs["cond"]
            if cond.dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not have a size of one"

        if check:
            cond = kwargs["cond"]
            # Size of 1 is equivalent to rank 0
            if len(cond.shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListOutputListMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListOutputListMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match output list"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_outputs = while_block.outputs
            while_tens = while_block.tensors
            if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListCondGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListCondGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match cond graph"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            cond_block = basicBlocks[1]
            cond_inputs = cond_block.inputs
            cond_tens = cond_block.tensors
            if (
                while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
            ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphInputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphInputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph input"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph output"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_outputs = body_block.outputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output is not a boolean tensor"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output does not have a shape of size one"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            # Size of 1 is equivalent to rank 0
            if len(cond_tens[cond_outputs[0]].shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelNotPowerOfTwo(check=False, **kwargs):
        error_name = ErrorIf.KernelNotPowerOfTwo
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "kernel height and/or width not a power of two"

        def is_power_of_two(x):
            return math.log(x, 2).is_integer()

        if check:
            shape = kwargs["input_shape"]
            if len(shape) == 3:
                valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
                error_result = not valid_kernel

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

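    # For example (assumed shapes): FFT operators expect power-of-two height
    # and width, so an input of [N, 8, 16] passes while [N, 8, 12] is flagged
    # by the check above.
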
    @staticmethod
    def evFFTInputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTInputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between real and imaginary input shapes"

        if check:
            input1 = kwargs["input1"]
            input2 = kwargs["input2"]

            if input1.shape != input2.shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evFFTOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between provided and expected output kernel (H, W) shape"
        )

        if check:
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]

            if len(input_shape) == 3:
                output_shapes = kwargs["output_shape"]

                # Ignoring batch size (N) from input shape
                expected_shape = input_shape[1:]
                if op["op"] == Op.RFFT2D:
                    expected_shape[1] = expected_shape[1] // 2 + 1

                # Ignoring batch size (N) from output shapes
                output_shape_0 = output_shapes[0][1:]
                output_shape_1 = output_shapes[1][1:]
                # Ensure the kernel sizes (H, W) of both outputs match the expected shape
                if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

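    # For example (assumed shapes): an RFFT2D input of [N, 8, 8] should
    # produce real and imaginary outputs whose (H, W) are both
    # [8, 8 // 2 + 1] = [8, 5]; differing or wrongly sized outputs trigger
    # this check.
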
    @staticmethod
    def calculateBroadcastShape(input_shape_a, input_shape_b):
        if input_shape_a is not None and input_shape_b is not None:
            calculated_shape = input_shape_a.copy()
            for idx in range(len(calculated_shape)):
                if calculated_shape[idx] == 1:
                    calculated_shape[idx] = input_shape_b[idx]
                elif (
                    input_shape_b[idx] != 1
                    and input_shape_b[idx] != calculated_shape[idx]
                ):
                    return None
            return calculated_shape
        else:
            return None

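    # Worked example (assumed shapes): [1, 3, 1] and [2, 3, 4] broadcast to
    # [2, 3, 4], while [2, 3, 1] and [3, 3, 4] return None because 2 and 3
    # on the first axis cannot be reconciled.
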
    @staticmethod
    def evBroadcastShapesMismatch(check=False, **kwargs):
        error_name = ErrorIf.BroadcastShapesMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Broadcast shape calculation failed"

        if check:
            input_shape_a = kwargs["input1"].shape
            input_shape_b = kwargs["input2"].shape
            input_shape_c = (
                kwargs["input3"].shape if "input3" in kwargs else input_shape_b
            )

            if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input_shape_c,
                    TosaErrorValidator.calculateBroadcastShape(
                        input_shape_a, input_shape_b
                    ),
                )
                error_result = calculated_shape is None

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongAccumulatorType(check=False, **kwargs):
        error_name = ErrorIf.WrongAccumulatorType
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "An unsupported accumulator data type was requested"

        if check:
            op = kwargs["op"]
            input_dtype = kwargs["input_dtype"]
            accum_dtype = kwargs["accum_dtype"]
            if op["op"] == Op.AVG_POOL2D:
                if (
                    input_dtype in (DType.INT8, DType.INT16)
                    and accum_dtype != DType.INT32
                ):
                    error_result = True
                elif (
                    input_dtype in (DType.FP32, DType.BF16)
                    and accum_dtype != DType.FP32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and accum_dtype not in (
                    DType.FP16,
                    DType.FP32,
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

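    # For instance (assumed combinations): AVG_POOL2D with an INT8 input must
    # accumulate in INT32, and an FP16 input may accumulate in either FP16 or
    # FP32; any other pairing is reported as an unsupported accumulator type.
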

class TosaInvalidValidator:
    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args_dict = kwargs["args"]
        mode = args_dict["mode"]
        output_dtype = args_dict["output_dtype"]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (input_dtype != output_dtype) or (
                input_dtype
                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
            )
        else:
            # Invalid resize mode
            return True

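    # For example (assumed combination): a BILINEAR resize of an INT8 input
    # must produce an INT32 output, so INT8 -> INT8 under BILINEAR is
    # reported as invalid by the check above.
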
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        if isinstance(args, dict):
            args_dict = args
        else:
            # Create args_dict from list elements
            # TODO - Remove this once all NHWC operators agFunctions have been
            # converted to args_dict output

            # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
            stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
            args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
            # Alias different info for each op
            args_dict["kernel"] = args[pad_idx + 1]
            args_dict["out_shape"] = args[pad_idx + 1]
            args_dict["dilation"] = args[pad_idx + 1]

        # Common info for all ops
        strides = args_dict["stride"]
        padding = args_dict["pad"]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args_dict["kernel"]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args_dict["out_shape"]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension."""
                return (in_size - 1) * stride + kernel_size + in_pad + out_pad

            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                padding[0],
                padding[1],
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                padding[2],
                padding[3],
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False
            # output shape does not match the expected shape
            return True

2724 if "conv2d" in opName or "conv3d" in opName:
2725 # conv2d, conv3d, depthwise_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002726 dilations = args_dict["dilation"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002727 filter_shape = inputShapes[1]
2728 kernel_shape = (
2729 filter_shape[0:2]
2730 if opName.startswith("depthwise_conv2d")
2731 else filter_shape[1:-1]
2732 )
2733
2734 for i in range(len(kernel_shape)):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002735 pad_offset = i * 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002736 dim = (
2737 input_shape[i + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002738 - 1
2739 + padding[pad_offset]
2740 + padding[pad_offset + 1]
2741 - (kernel_shape[i] - 1) * dilations[i]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002742 ) // strides[i] + 1
2743 # return True if any dimension is < 1
2744 if dim < 1:
2745 return True
2746 return False
2747
2748 assert False, f"Unrecognized Op: {opName}"
2749
    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs["args"]
        output_shape = args["out_shape"]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
            return True
        return False
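

# A minimal usage sketch (calling convention assumed, mirroring the kwargs the
# checks above read): each validator returns its info_dict and callers inspect
# "error_result" to see whether the ERROR_IF condition fired, e.g.
#
#   result = TosaErrorValidator.evMaxSmallerMin(check=True, max_val=1, min_val=5)
#   assert result["error_result"] is True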