blob: 90c34282385b7075685266b1e85044c2c2f42d2b [file] [log] [blame]
Won Jeon74342e52024-01-09 00:34:40 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Luke Hutton261b7b62023-01-10 14:50:31 +00003import math
4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01006from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from generator.tosa_utils import product
8from generator.tosa_utils import usableDTypes
9from generator.tosa_utils import valueToName
10from tosa.DType import DType
11from tosa.Op import Op
12from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000013
Matthew Haddone86fd342021-09-07 16:12:21 +010014
15class ErrorIf(object):
16 MaxDimExceeded = "MaxDimExceeded"
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010017 ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
18 ScaleNLargerMax = "ScaleNLargerMax"
19 ScaleDLargerMax = "ScaleDLargerMax"
20 OffsetSmallerMin = "OffsetSmallerMin"
Matthew Haddone86fd342021-09-07 16:12:21 +010021 OffsetLargerEqualMax = "OffsetLargerEqualMax"
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010022 BorderSmallerMin = "BorderSmallerMin"
23 BorderLargerEqualMax = "BorderLargerEqualMax"
24 ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
25 ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
Matthew Haddon848efb42021-09-09 12:30:53 +010026 WrongInputType = "WrongInputType"
27 WrongOutputType = "WrongOutputType"
28 WrongInputList = "WrongInputList"
29 WrongOutputList = "WrongOutputList"
30 WrongRank = "WrongRank"
Matthew Haddon693ba9e2021-09-22 11:24:37 +010031 BatchMismatch = "BatchMismatch"
32 ChannelMismatch = "ChannelMismatch"
Matthew Haddoneacff9a2021-09-24 14:42:13 +010033 RankMismatch = "RankMismatch"
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +000034 DimensionMismatch = "DimensionMismatch"
Matthew Haddone4ecdb22021-09-28 11:38:21 +010035 InputZeroPointNotZero = "InputZeroPointNotZero"
Matthew Haddonc4cf0372021-10-11 09:38:10 +010036 WeightZeroPointNotZero = "WeightZeroPointNotZero"
Matthew Haddone4ecdb22021-09-28 11:38:21 +010037 OutputZeroPointNotZero = "OutputZeroPointNotZero"
Matthew Haddond6ce7252021-09-29 15:35:44 +010038 AxisSmallerZero = "AxisSmallerZero"
39 AxisLargerRank = "AxisLargerRank"
Matthew Haddonc4cf0372021-10-11 09:38:10 +010040 ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
41 ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
Matthew Haddond6ce7252021-09-29 15:35:44 +010042 ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
Matthew Haddonb6b59e32021-10-07 17:19:20 +010043 KernelSmallerOne = "KernelSmallerOne"
44 StrideSmallerOne = "StrideSmallerOne"
Les Bell0e027d42021-11-09 14:42:14 +000045 DilationSmallerOne = "DilationSmallerOne"
Matthew Haddonb6b59e32021-10-07 17:19:20 +010046 PadSmallerZero = "PadSmallerZero"
47 PadLargerEqualKernel = "PadLargerEqualKernel"
Jeremy Johnsond32c6da2022-08-24 17:09:09 +010048 PadOutputShapeMismatch = "PadOutputShapeMismatch"
Matthew Haddonb6b59e32021-10-07 17:19:20 +010049 PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +010050 PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
51 ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
52 ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
Matthew Haddonc2025212021-10-08 21:21:05 +010053 ScaleNotTrue = "ScaleNotTrue"
54 ScaleTrue = "ScaleTrue"
Matthew Haddone807aae2021-10-11 18:12:58 +010055 TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
56 StartSmallerZero = "StartSmallerZero"
57 SizeSmallerEqualZero = "SizeSmallerEqualZero"
58 StartSizeOutsideBounds = "StartSizeOutsideBounds"
59 SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
60 InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
61 IndexOutsideBounds = "IndexOutsideBounds"
62 IndexUsedTwice = "IndexUsedTwice"
Matthew Haddonbb5676f2021-10-13 11:30:30 +010063 MaxSmallerMin = "MaxSmallerMin"
64 ConcatInputRankMismatch = "ConcatInputRankMismatch"
65 ConcatInputDimMismatch = "ConcatInputDimMismatch"
Matthew Haddon01c359d2021-10-15 16:30:48 +010066 ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
Matthew Haddon630c17c2021-10-14 15:05:41 +010067 CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
68 CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
69 CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
70 CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
71 InputListOutputListMismatch = "InputListOutputListMismatch"
72 InputListCondGraphMismatch = "InputListCondGraphMismatch"
73 InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
74 InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
75 CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010076 U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
77 U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson05c711e2022-12-12 18:00:41 +000078 CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
79 CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
80 CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
Luke Hutton261b7b62023-01-10 14:50:31 +000081 KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
Luke Hutton57287132023-02-06 14:54:18 +000082 FFTInputShapeMismatch = "FFTInputShapeMismatch"
83 FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
Jerry Ge264f7fa2023-04-21 22:49:57 +000084 ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
85 ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
Jerry Ge135c9552023-05-23 20:59:32 +000086 BroadcastShapesMismatch = "BroadcastShapesMismatch"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010087
88
89class TosaErrorIfArgGen:
90 @staticmethod
91 def eiResizeErrorIf(
92 testGen,
93 error_name,
94 mode,
95 dtype,
96 shapeList,
97 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010098 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010099 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100100 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100101 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100102 if error_name == ErrorIf.ScaleSmallerEqualZero:
103 index = testGen.randInt(low=0, high=4)
104 scale[index] = testGen.rng.choice([-2, -1, 0])
105 elif error_name == ErrorIf.ScaleNLargerMax:
106 index = testGen.rng.choice([0, 2])
107 scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
108 elif error_name == ErrorIf.ScaleDLargerMax:
109 index = testGen.rng.choice([1, 3])
110 scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100111
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100112 if error_name == ErrorIf.OffsetLargerEqualMax:
113 index = testGen.rng.choice([0, 1])
114 offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
115 elif error_name == ErrorIf.OffsetSmallerMin:
116 index = testGen.rng.choice([0, 1])
117 offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])
118
119 if error_name == ErrorIf.BorderLargerEqualMax:
120 index = testGen.rng.choice([0, 1])
121 border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
122 elif error_name == ErrorIf.BorderSmallerMin:
123 index = testGen.rng.choice([0, 1])
124 border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100125
126 if error_name == ErrorIf.WrongOutputType:
127 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
128 incorrect_types = (
129 DType.INT4,
130 DType.INT16,
131 DType.INT32,
132 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100133 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100134 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100135 )
136 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
137 incorrect_types = (
138 DType.INT4,
139 DType.INT8,
140 DType.INT32,
141 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100142 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100143 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100144 )
145 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
146 incorrect_types = (
147 DType.INT4,
148 DType.INT8,
149 DType.INT16,
150 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100151 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100152 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100153 )
154 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
155 incorrect_types = (
156 DType.INT4,
157 DType.INT8,
158 DType.INT16,
159 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100160 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100161 DType.FP16,
162 )
163 elif dtype == DType.FP16:
164 incorrect_types = (
165 DType.INT4,
166 DType.INT8,
167 DType.INT16,
168 DType.INT32,
169 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100170 DType.FP32,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100171 )
James Ward24dbc422022-10-19 12:20:31 +0100172 elif dtype == DType.BF16:
173 incorrect_types = (
174 DType.INT4,
175 DType.INT8,
176 DType.INT16,
177 DType.INT32,
178 DType.INT48,
179 DType.FP32,
180 )
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100181 elif dtype == DType.FP32:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100182 incorrect_types = (
183 DType.INT4,
184 DType.INT8,
185 DType.INT16,
186 DType.INT32,
187 DType.INT48,
James Ward8b390432022-08-12 20:48:56 +0100188 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100189 )
190 outputDType = testGen.rng.choice(a=incorrect_types)
191
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100192 return scale, offset, border, outputDType
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100193
194 @staticmethod
195 def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
196 if (
197 error_name == ErrorIf.StrideSmallerOne
198 # padding must not exceed the kernel size
199 and pad[0] < kernel[0]
200 and pad[1] < kernel[0]
201 and pad[2] < kernel[1]
202 and pad[3] < kernel[1]
203 ):
204 wrongStride = (
205 testGen.rng.choice([0, -1, -2, -3]),
206 testGen.rng.choice([0, -1, -2, -3]),
207 )
208 return wrongStride, pad, kernel
209 elif error_name == ErrorIf.PadSmallerZero:
210 wrongPad = (
211 testGen.rng.choice([-1, -2, -3]),
212 testGen.rng.choice([-1, -2, -3]),
213 testGen.rng.choice([-1, -2, -3]),
214 testGen.rng.choice([-1, -2, -3]),
215 )
216 return stride, wrongPad, kernel
217 elif error_name == ErrorIf.KernelSmallerOne:
218 wrongKernel = (
219 testGen.rng.choice([0, -1, -2, -3]),
220 testGen.rng.choice([0, -1, -2, -3]),
221 )
222 return stride, pad, wrongKernel
223 elif error_name == ErrorIf.PadLargerEqualKernel:
224 wrongPad = (
225 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
226 testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
227 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
228 testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
229 )
230 return stride, wrongPad, kernel
231 else:
232 return None, None, None
233
234 @staticmethod
235 def eiRescaleWrongOutputType(input_dtype, output_dtype):
236 if input_dtype == DType.INT8:
237 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
238 return True
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100239 elif input_dtype == DType.INT16:
240 if output_dtype not in [
241 DType.UINT8,
242 DType.INT8,
243 DType.UINT16,
244 DType.INT16,
245 DType.INT32,
246 ]:
247 return True
248 elif input_dtype == DType.INT32:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100249 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
250 return True
251 elif input_dtype == DType.INT48:
252 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
253 return True
254 elif input_dtype == DType.UINT8:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100255 if output_dtype not in [DType.INT8, DType.INT16]:
256 return True
257 elif input_dtype == DType.UINT16:
258 if output_dtype != DType.INT16:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100259 return True
260 return False
261
262 @staticmethod
263 def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
264 # Mess up input/output tensors for ERROR_IF checks
265 if error_name == "WrongInputList":
266 add_input = testGen.rng.choice([True, False])
267 if add_input:
268 input_list.append("eiDummyInput")
269 else:
270 input_list = input_list[:-1]
271 elif error_name == "WrongOutputList":
272 add_output = testGen.rng.choice([True, False])
273 if add_output:
274 output_list.append("eiDummyOutput")
275 else:
276 output_list = []
277 return input_list, output_list
278
279 @staticmethod
280 def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
281 """Restrict the dimensions and overall size of a shape to
282 max_dim and max_items.
283 """
284 new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
285 while product(new_shape) > max_items:
286 new_shape = [max(d - 1, 1) for d in new_shape]
287 return new_shape
288
289 def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
290 if error_name == ErrorIf.StartSmallerZero:
291 newStart = []
292 for i in range(len(input_shape)):
293 newStart.append(testGen.rng.choice([-3, -2, -1]))
294 return newStart, size
295 elif error_name == ErrorIf.SizeSmallerEqualZero:
296 newSize = []
297 for i in range(len(input_shape)):
298 newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
299 return start, newSize
300 elif error_name == ErrorIf.StartSizeOutsideBounds:
301 newStart, newSize = [], []
302 for i in range(len(input_shape)):
303 newStart.append(input_shape[i] - 1)
304 newSize.append(testGen.rng.choice([2, 3, 4]))
305 return newStart, newSize
306 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
307 remove = testGen.rng.choice([True, False])
TatWai Chongf15bad82024-01-31 21:33:27 -0800308
309 # Get an empty tensor when diminishing dimension on 1-d tensor.
310 if len(start) == 1 or len(size) == 1:
311 remove = False
312
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100313 if remove:
314 newStart = start[1:]
315 newSize = size[1:]
316 else:
317 newStart = start
318 newStart.append(1)
319 newSize = size
320 newSize.append(1)
321 return newStart, newSize
322 else:
323 return start, size
324
325 @staticmethod
326 def eiCastErrorIf(testGen, input_dtype):
James Ward736fd1a2023-01-23 17:13:37 +0000327 if input_dtype in [DType.BOOL, DType.FP32]:
328 outputDType = [DType.BOOL, DType.INT48, DType.FP32]
329 elif input_dtype in [DType.FP16, DType.BF16]:
330 outputDType = [DType.BOOL, DType.INT48]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100331 elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
332 outputDType = [DType.INT48]
333 else:
James Ward736fd1a2023-01-23 17:13:37 +0000334 assert False, f"input_dtype ({input_dtype}) not supported"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100335 return outputDType
336
337
338class TosaErrorValidator:
339 @staticmethod
340 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
341 """Check ERROR_IF statements are caught and set the expected result.
342
343 Args:
344 serializer: the serializer to set the expected result in
345 validator_fcns: a sequence of validator functions to verify the result
346 error_name: the name of the ERROR_IF condition to check for
347 kwargs: keyword arguments for the validator functions
348 Returns:
349 True if the result matches the expected result; otherwise False
350 """
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000351 if validator_fcns is None:
352 # Nothing to do
353 return True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100354 overall_result = True
355 for val_fcn in validator_fcns:
356 val_result = val_fcn(True, **kwargs)
357 validator_name = val_result["error_name"]
358 error_result = val_result["error_result"]
359 error_reason = val_result["error_reason"]
360
361 # expect an error IFF the error_name and validator_name match
362 expected_result = error_result == (error_name == validator_name)
363 overall_result &= expected_result
364
365 if expected_result and error_result:
366 serializer.setExpectedReturnCode(2, True, desc=error_reason)
367 elif error_result: # and not expected_result
368 print(
369 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
370 f" Expected: {error_name}, Got: {validator_name}"
371 )
372 elif not expected_result: # and not error_result
373 print(
374 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
375 f" Expected: {error_name}"
376 )
377
378 if not expected_result:
379 for k, v in sorted(kwargs.items()):
380 if k != "op":
381 if k.endswith("dtype"):
382 v = valueToName(DType, v)
383 print(f" {k} = {v}")
384
385 return overall_result
386
387 @staticmethod
388 def evWrongInputType(check=False, **kwargs):
389 error_result = False
390
391 # Find the unsupported input data types
392 op = kwargs["op"]
393 input_dtypes = op["types"]
394 allowed_input_dtypes = {
395 t[0] if isinstance(t, list) else t for t in input_dtypes
396 }
397 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
398
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100399 # Turn the wrong dtypes into required list of types
400 if op["op"] in [
401 Op.FULLY_CONNECTED,
402 Op.CONV2D,
403 Op.CONV3D,
404 Op.DEPTHWISE_CONV2D,
405 Op.TRANSPOSE_CONV2D,
406 ]:
407 wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
408
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100409 if op["op"] == Op.CLAMP:
410 wrong_input_dtypes.remove(DType.INT48)
411
412 if check:
413 input_dtype = kwargs["input_dtype"]
414 if input_dtype not in allowed_input_dtypes:
415 error_result = True
416
417 info_dict = {
418 "error_name": ErrorIf.WrongInputType,
419 "error_result": error_result,
420 "error_reason": "Input data type not supported for this operator",
421 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
422 }
423 return info_dict
424
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for ErrorIf.WrongOutputType.

        Checks whether kwargs["output_dtype"] is invalid for the operator
        (kwargs["op"]) given kwargs["input_dtype"] (and, for RESIZE,
        kwargs["mode"]), and returns the standard validator info dict.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            # RESIZE: integer modes have fixed output types; float types
            # must be preserved
            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            # RESCALE: delegate to the shared in/out type compatibility table
            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            # Accumulating ops: integer types widen, FP16 may accumulate
            # in FP16 or FP32
            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            # ARGMAX always produces INT32 indices
            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            # MUL: integer inputs produce INT32, float types are preserved
            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            # TABLE: INT8 lookup gives INT8, INT16 lookup gives INT32
            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            # Comparison ops always produce BOOL
            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            # CAST: per-input-type allow-lists (self-casts are invalid)
            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            # FFT ops: every output tensor must match the input type
            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            # Convolutions: same widening rules as FULLY_CONNECTED/MATMUL
            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            # Shape arithmetic ops always produce SHAPE
            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
                if output_dtype != DType.SHAPE:
                    error_result = True

            # Default: output type must match input type
            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
620
621 @staticmethod
622 def evWrongRank(check=False, **kwargs):
623 all_ranks = (1, 2, 3, 4, 5)
624
625 # Make a list of incorrect ranks
626 assert "op" in kwargs
627 op = kwargs["op"]
628 rmin, rmax = op["rank"]
629 rank_range = range(rmin, rmax + 1)
630 incorrect_ranks = list(set(all_ranks) - set(rank_range))
631 # Remove small incorrect ranks to avoid index errors
632 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
633 # Set minimum incorrect rank to 3 to avoid index error
634 if op["op"] in [Op.RESIZE]:
635 incorrect_ranks = [3, 5]
636 elif op["op"] in [Op.TRANSPOSE]:
637 incorrect_ranks = [7, 8]
638 elif op["op"] in [Op.CONV3D]:
639 incorrect_ranks = [6, 7]
640
641 error_name = ErrorIf.WrongRank
642 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
643 error_result = False
644 error_reason = "Rank not supported for this operator"
645
646 if check:
647 input_shape = kwargs["input_shape"]
648
649 if (
650 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
651 and len(input_shape) != 4
652 ):
653 error_result = True
654 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
655 error_result = True
656 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
657 error_result = True
658 else:
659 if len(input_shape) not in rank_range:
660 error_result = True
661
662 info_dict = {
663 "error_name": error_name,
664 "error_result": error_result,
665 "error_reason": error_reason,
666 "param_reqs": param_reqs,
667 }
668 return info_dict
669
670 @staticmethod
671 def evWrongInputList(check=False, **kwargs):
672 error_name = ErrorIf.WrongInputList
673 param_reqs = {"rank": None, "dtype": None, "shape": None}
674 error_result = False
675 error_reason = "Op input list does not match expected input"
676
677 if check:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100678 input_list = kwargs["input_list"]
679 num_operands = kwargs["num_operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100680 if len(input_list) != num_operands:
681 error_result = True
682
683 info_dict = {
684 "error_name": error_name,
685 "error_result": error_result,
686 "error_reason": error_reason,
687 "param_reqs": param_reqs,
688 }
689 return info_dict
690
691 @staticmethod
692 def evWrongOutputList(check=False, **kwargs):
693 error_name = ErrorIf.WrongOutputList
694 param_reqs = {"rank": None, "dtype": None, "shape": None}
695 error_result = False
696 error_reason = "Op output list does not match expected output"
697
698 if check:
Luke Hutton261b7b62023-01-10 14:50:31 +0000699 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100700 output_list = kwargs["output_list"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000701 expected_length = 1
Luke Hutton57287132023-02-06 14:54:18 +0000702 if op["op"] in [Op.FFT2D, Op.RFFT2D]:
Luke Hutton261b7b62023-01-10 14:50:31 +0000703 expected_length = 2
704
705 if len(output_list) != expected_length:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100706 error_result = True
707
708 info_dict = {
709 "error_name": error_name,
710 "error_result": error_result,
711 "error_reason": error_reason,
712 "param_reqs": param_reqs,
713 }
714 return info_dict
715
716 @staticmethod
717 def evMaxDimExceeded(check=False, **kwargs):
718 error_name = ErrorIf.MaxDimExceeded
719 param_reqs = {
720 "rank": [4, 4],
721 "dtype": [DType.INT8],
722 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
723 }
724 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100725 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100726
727 if check:
728 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100729 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100730 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100731 (input_shape[1] >= MAX_RESIZE_DIMENSION)
732 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
733 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
734 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100735 ):
736 error_result = True
737
738 info_dict = {
739 "error_name": error_name,
740 "error_result": error_result,
741 "error_reason": error_reason,
742 "param_reqs": param_reqs,
743 }
744 return info_dict
745
746 @staticmethod
747 def evBatchMismatch(check=False, **kwargs):
748 error_name = ErrorIf.BatchMismatch
Luke Hutton261b7b62023-01-10 14:50:31 +0000749 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100750 error_result = False
751 error_reason = "Input batch size not equal to output batch size"
752
753 assert "op" in kwargs
754 op = kwargs["op"]
755 rmin, rmax = op["rank"]
756 rank_range = range(rmin, rmax + 1)
757
758 if check:
759 input_shape = kwargs["input_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100760
Luke Hutton261b7b62023-01-10 14:50:31 +0000761 for output in kwargs["result_tensors"]:
762 output_shape = (
763 output.shape
764 ) # Note batch is expected to be the first dim
765 if (len(input_shape) in rank_range) and (
766 input_shape[0] != output_shape[0]
767 ):
768 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100769
770 info_dict = {
771 "error_name": error_name,
772 "error_result": error_result,
773 "error_reason": error_reason,
774 "param_reqs": param_reqs,
775 }
776 return info_dict
777
778 @staticmethod
779 def evChannelMismatch(check=False, **kwargs):
780 error_name = ErrorIf.ChannelMismatch
781 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
782 error_result = False
783 error_reason = "Input channel size not equal to output channel size"
784
785 assert "op" in kwargs
786 op = kwargs["op"]
787 rmin, rmax = op["rank"]
788 rank_range = range(rmin, rmax + 1)
789
790 if check:
791 input_shape = kwargs["input_shape"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000792 for output in kwargs["result_tensors"]:
793 output_shape = output.shape # Note this is just (N, OH, OW, C)
794 if (len(input_shape) in rank_range) and (
795 input_shape[3] != output_shape[3]
796 ):
797 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100798
799 info_dict = {
800 "error_name": error_name,
801 "error_result": error_result,
802 "error_reason": error_reason,
803 "param_reqs": param_reqs,
804 }
805 return info_dict
806
807 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100808 def evScaleSmallerEqualZero(check=False, **kwargs):
809 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100810 param_reqs = {"rank": None, "dtype": None, "shape": None}
811 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100812 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100813
814 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100815 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100816
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100817 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100818 error_result = True
819
820 info_dict = {
821 "error_name": error_name,
822 "error_result": error_result,
823 "error_reason": error_reason,
824 "param_reqs": param_reqs,
825 }
826 return info_dict
827
828 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100829 def evScaleNLargerMax(check=False, **kwargs):
830 error_name = ErrorIf.ScaleNLargerMax
831 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100832 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100833 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100834
835 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100836 scale = kwargs["scale"]
837
838 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
839 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100840
841 info_dict = {
842 "error_name": error_name,
843 "error_result": error_result,
844 "error_reason": error_reason,
845 "param_reqs": param_reqs,
846 }
847 return info_dict
848
849 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100850 def evScaleDLargerMax(check=False, **kwargs):
851 error_name = ErrorIf.ScaleDLargerMax
852 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100853 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100854 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100855
856 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100857 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100858
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100859 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
860 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100861 ):
862 error_result = True
863
864 info_dict = {
865 "error_name": error_name,
866 "error_result": error_result,
867 "error_reason": error_reason,
868 "param_reqs": param_reqs,
869 }
870 return info_dict
871
872 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100873 def evOffsetSmallerMin(check=False, **kwargs):
874 error_name = ErrorIf.OffsetSmallerMin
875 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100876 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100877 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100878
879 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100880 scale = kwargs["scale"]
881 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100882
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100883 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100884 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100885 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100886 error_result = True
887
888 info_dict = {
889 "error_name": error_name,
890 "error_result": error_result,
891 "error_reason": error_reason,
892 "param_reqs": param_reqs,
893 }
894 return info_dict
895
896 @staticmethod
897 def evOffsetLargerEqualMax(check=False, **kwargs):
898 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100899 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100900 error_result = False
901 error_reason = "Offset value larger than or equal to maximum value"
902
903 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100904 scale = kwargs["scale"]
905 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100906
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100907 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
908 error_result = True
909 elif (
910 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
911 ):
912 error_result = True
913
914 info_dict = {
915 "error_name": error_name,
916 "error_result": error_result,
917 "error_reason": error_reason,
918 "param_reqs": param_reqs,
919 }
920 return info_dict
921
922 @staticmethod
923 def evBorderSmallerMin(check=False, **kwargs):
924 error_name = ErrorIf.BorderSmallerMin
925 param_reqs = {"rank": None, "dtype": None, "shape": None}
926 error_result = False
927 error_reason = "Border value smaller than minimum value"
928
929 if check:
930 scale = kwargs["scale"]
931 border = kwargs["border"]
932
933 if (
934 scale[0] > 0
935 and scale[0] <= (1 << 11)
936 and (border[0] < (-16 * scale[0]))
937 ):
938 error_result = True
939 elif (
940 scale[2] > 0
941 and scale[2] <= (1 << 11)
942 and (border[1] < (-16 * scale[2]))
943 ):
944 error_result = True
945
946 info_dict = {
947 "error_name": error_name,
948 "error_result": error_result,
949 "error_reason": error_reason,
950 "param_reqs": param_reqs,
951 }
952 return info_dict
953
954 @staticmethod
955 def evBorderLargerEqualMax(check=False, **kwargs):
956 error_name = ErrorIf.BorderLargerEqualMax
957 param_reqs = {"rank": None, "dtype": None, "shape": None}
958 error_result = False
959 error_reason = "Border value larger than or equal to maximum value"
960
961 if check:
962 scale = kwargs["scale"]
963 border = kwargs["border"]
964
965 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
966 error_result = True
967 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
968 error_result = True
969
970 info_dict = {
971 "error_name": error_name,
972 "error_result": error_result,
973 "error_reason": error_reason,
974 "param_reqs": param_reqs,
975 }
976 return info_dict
977
978 @staticmethod
979 def checkResizeParams(scale, offset, border):
980 return (
981 min(scale) > 0
982 and max(scale[0], scale[2]) <= (1 << 11)
983 and scale[1] < 16 * scale[0]
984 and scale[3] < 16 * scale[2]
985 and offset[0] >= -scale[0]
986 and offset[1] >= -scale[2]
987 and offset[0] < 16 * scale[0]
988 and offset[1] < 16 * scale[2]
989 and border[0] >= -16 * scale[0]
990 and border[1] >= -16 * scale[2]
991 and border[0] < scale[0]
992 and border[1] < scale[2]
993 )
994
    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        """Check ERROR_IF condition: RESIZE output shape differs from the shape
        derived from input shape, scale, offset and border.

        Expected size per spatial axis is
        ((in - 1) * scale_n - offset + border) // scale_d + 1.  The
        comparison is skipped when the resize parameters are themselves
        invalid or any dimension reaches MAX_RESIZE_DIMENSION (those cases
        belong to other ERROR_IF checks).
        """
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                # Expected spatial sizes; H uses scale[0:2]/offset[0]/border[0],
                # W uses scale[2:4]/offset[1]/border[1]
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                # Compare against the H/W dims of the provided output shape
                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1036
    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        """Check ERROR_IF condition: resize parameters do not divide evenly.

        Flags the case where ((in - 1) * scale_n - offset + border) is not
        an exact multiple of the scale denominator on either spatial axis.
        Skipped when the resize parameters are themselves invalid.
        """
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                # Non-zero remainder means the output size is not an integer
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1071
1072 @staticmethod
1073 def evRankMismatch(check=False, **kwargs):
1074 error_name = ErrorIf.RankMismatch
1075 param_reqs = {"rank": None, "dtype": None, "shape": None}
1076 error_result = False
1077 error_reason = "Input Rank does not match output rank"
1078
1079 if check:
1080 input1_shape = kwargs["input1"].shape
Luke Huttona4e48ca2023-02-22 11:53:48 +00001081 input2_shape = (
1082 kwargs["input2"].shape if "input2" in kwargs else input1_shape
1083 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001084 # In case of SELECT op
1085 input3_shape = (
1086 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1087 )
Luke Hutton261b7b62023-01-10 14:50:31 +00001088
1089 for output in kwargs["result_tensors"]:
1090 output_shape = output.shape
1091 if (
1092 (len(input1_shape) != len(output_shape))
1093 or (len(input2_shape) != len(output_shape))
1094 or (len(input3_shape) != len(output_shape))
1095 ):
1096 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001097
1098 info_dict = {
1099 "error_name": error_name,
1100 "error_result": error_result,
1101 "error_reason": error_reason,
1102 "param_reqs": param_reqs,
1103 }
1104 return info_dict
1105
    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        """Check ERROR_IF condition: input dimensions do not match the output.

        Shape ops (ADD_SHAPE etc.) require the single output to equal
        input1 exactly; broadcastable ops require every output to equal the
        broadcast of all (equal-rank) input shapes.
        """
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            op = kwargs["op"]
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
                # Shape ops: no broadcasting, output must match input1 exactly
                output_shape = kwargs["result_tensors"][0].shape
                if input1_shape != output_shape:
                    error_result = True

            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                # Fold all inputs through the broadcast rules; None means the
                # inputs themselves do not broadcast together
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1148
1149 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001150 def _getZeroPoint(qinfo, index):
1151 """Return zero point value from quantization info.
1152
1153 Generally input_zp is index 0, output_zp is index 1
1154 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001155 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001156
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """Check ERROR_IF condition: input zero point must be 0 unless the
        input dtype is one of the quantizable types (INT8/UINT8/UINT16).

        MATMUL validates the zero points of both operands; every other op
        validates only the first input's zero point.
        """
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        # (op["types"] entries may be per-operand lists; then the first
        # entry is the input dtype)
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                # Either operand with a non-quantizable dtype and a non-zero
                # zero point triggers the error
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict
1195
1196 @staticmethod
1197 def evWeightZeroPointNotZero(check=False, **kwargs):
1198 op = kwargs["op"]
1199
1200 # exclude inputs with INT8 weights
1201 inputDtypes = [
1202 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1203 ]
1204
1205 error_name = ErrorIf.WeightZeroPointNotZero
1206 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1207 error_result = False
1208 error_reason = "Weight DType not INT8 and zero point not 0"
1209
1210 if check:
1211 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001212 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001213 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1214 error_result = True
1215
1216 info_dict = {
1217 "error_name": error_name,
1218 "error_result": error_result,
1219 "error_reason": error_reason,
1220 "param_reqs": param_reqs,
1221 }
1222 return info_dict
1223
1224 @staticmethod
1225 def evOutputZeroPointNotZero(check=False, **kwargs):
1226 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001227 inputDtypes = [
1228 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1229 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001230
1231 error_name = ErrorIf.OutputZeroPointNotZero
1232 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1233 error_result = False
1234 error_reason = "Output DType not INT8 and zero point not 0"
1235
1236 if check:
1237 input_dtype = kwargs["input_dtype"]
1238 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001239 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001240 if op["op"] == Op.AVG_POOL2D:
1241 if input_dtype != DType.INT8 and output_zero_point != 0:
1242 error_result = True
1243 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001244 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1245 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001246 ):
1247 error_result = True
1248
1249 info_dict = {
1250 "error_name": error_name,
1251 "error_result": error_result,
1252 "error_reason": error_reason,
1253 "param_reqs": param_reqs,
1254 }
1255 return info_dict
1256
1257 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001258 def evU16InputZeroPointNotValid(check=False, **kwargs):
1259 error_name = ErrorIf.U16InputZeroPointNotValid
1260 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1261 error_result = False
1262 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1263
1264 if check:
1265 input_dtype = kwargs["input_dtype"]
1266 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1267 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1268 0,
1269 32768,
1270 ]
1271
1272 info_dict = {
1273 "error_name": error_name,
1274 "error_result": error_result,
1275 "error_reason": error_reason,
1276 "param_reqs": param_reqs,
1277 }
1278 return info_dict
1279
1280 @staticmethod
1281 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1282 error_name = ErrorIf.U16OutputZeroPointNotValid
1283 param_reqs = {"rank": None, "dtype": None, "shape": None}
1284 error_result = False
1285 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1286
1287 if check:
1288 output_dtype = kwargs["output_dtype"]
1289 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1290
1291 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1292 0,
1293 32768,
1294 ]
1295
1296 info_dict = {
1297 "error_name": error_name,
1298 "error_result": error_result,
1299 "error_reason": error_reason,
1300 "param_reqs": param_reqs,
1301 }
1302 return info_dict
1303
1304 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001305 def evAxisSmallerZero(check=False, **kwargs):
1306 error_name = ErrorIf.AxisSmallerZero
1307 param_reqs = {"rank": None, "dtype": None, "shape": None}
1308 error_result = False
1309 error_reason = "Axis smaller than zero"
1310
1311 if check:
1312 axis = kwargs["axis"]
1313 if axis < 0:
1314 error_result = True
1315
1316 info_dict = {
1317 "error_name": error_name,
1318 "error_result": error_result,
1319 "error_reason": error_reason,
1320 "param_reqs": param_reqs,
1321 }
1322 return info_dict
1323
1324 @staticmethod
1325 def evAxisLargerRank(check=False, **kwargs):
1326 error_name = ErrorIf.AxisLargerRank
1327 param_reqs = {"rank": None, "dtype": None, "shape": None}
1328 error_result = False
1329 error_reason = "Axis larger than rank"
1330
1331 if check:
1332 axis = kwargs["axis"]
1333 shape = kwargs["input_shape"]
1334 if axis > len(shape):
1335 error_result = True
1336
1337 info_dict = {
1338 "error_name": error_name,
1339 "error_result": error_result,
1340 "error_reason": error_reason,
1341 "param_reqs": param_reqs,
1342 }
1343 return info_dict
1344
1345 @staticmethod
1346 def evShapeOfAxisNotOne(check=False, **kwargs):
1347 error_name = ErrorIf.ShapeOfAxisNotOne
1348 param_reqs = {"rank": None, "dtype": None, "shape": None}
1349 error_result = False
1350 error_reason = "shape[axis] is not equal to 1"
1351
1352 if check:
1353 axis = kwargs["axis"]
1354 shape = kwargs["output_shape"]
1355 if (0 <= axis < len(shape)) and shape[axis] != 1:
1356 error_result = True
1357
1358 info_dict = {
1359 "error_name": error_name,
1360 "error_result": error_result,
1361 "error_reason": error_reason,
1362 "param_reqs": param_reqs,
1363 }
1364 return info_dict
1365
1366 @staticmethod
1367 def evPadSmallerZero(check=False, **kwargs):
1368 error_name = ErrorIf.PadSmallerZero
1369 param_reqs = {"rank": None, "dtype": None, "shape": None}
1370 error_result = False
1371 error_reason = "At least one pad is smaller than zero"
1372
1373 if check:
1374 op = kwargs["op"]
1375 pad = kwargs["pad"]
1376 if op["op"] == Op.PAD:
1377 for padding in pad:
1378 if min(padding) < 0:
1379 error_result = True
1380 else:
1381 if min(pad) < 0:
1382 error_result = True
1383
1384 info_dict = {
1385 "error_name": error_name,
1386 "error_result": error_result,
1387 "error_reason": error_reason,
1388 "param_reqs": param_reqs,
1389 }
1390 return info_dict
1391
    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        """Check ERROR_IF condition: a pad value is too large for the kernel.

        TRANSPOSE_CONV2D takes its kernel from the weight shape (H, W) and
        flags pads at or below the negated kernel size; pooling ops flag
        any pad >= the kernel size on its axis.  Pad order is
        (top, bottom, left, right) against kernel (y, x).
        """
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d: kernel H/W are the middle weight dims
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op; only checked when every pad is positive and
                # the kernel exceeds 1 on both axes
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1431
1432 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001433 def evPadOutputShapeMismatch(check=False, **kwargs):
1434 error_name = ErrorIf.PadOutputShapeMismatch
1435 param_reqs = {"rank": None, "dtype": None, "shape": None}
1436 error_result = False
1437 error_reason = "Pad output shape mismatch for requested padding"
1438
1439 if check:
1440 pad = kwargs["pad"]
1441 input_shape = kwargs["input_shape"]
1442 output_shape = kwargs["output_shape"]
1443 for dim, padding in enumerate(pad):
1444 expected_size = input_shape[dim] + padding[0] + padding[1]
1445 if expected_size != output_shape[dim]:
1446 error_result = True
1447
1448 info_dict = {
1449 "error_name": error_name,
1450 "error_result": error_result,
1451 "error_reason": error_reason,
1452 "param_reqs": param_reqs,
1453 }
1454 return info_dict
1455
1456 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001457 def checkPoolingParams(kernel, stride, pad):
1458 return (
1459 min(kernel) >= 1
1460 and min(stride) >= 1
1461 and min(pad) >= 0
1462 and not (
1463 pad[0] >= kernel[0]
1464 or pad[1] >= kernel[0]
1465 or pad[2] >= kernel[1]
1466 or pad[3] >= kernel[1]
1467 )
1468 )
1469
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """Check ERROR_IF condition: pooling output H/W differ from
        ((in + pad_total - kernel) // stride) + 1.

        The comparison is skipped when the kernel/stride/pad combination is
        itself invalid (covered by other ERROR_IF checks); a valid
        combination implies non-zero strides, so y_correct/x_correct are
        always bound when read.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1513
    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        """Check ERROR_IF condition: pooling parameters leave a non-zero
        remainder, i.e. (in + pad_total - kernel) is not a multiple of the
        stride on some axis.

        Skipped when the kernel/stride/pad combination is itself invalid;
        a valid combination implies non-zero strides, so the remainders are
        always bound when read.
        """
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1551
1552 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001553 def checkConvParams(op, weight_shape, stride, pad, dilation):
1554 if op == Op.TRANSPOSE_CONV2D:
1555 pad_ok = (
1556 pad[0] > -weight_shape[1]
1557 and pad[1] > -weight_shape[1]
1558 and pad[2] > -weight_shape[2]
1559 and pad[3] > -weight_shape[2]
1560 )
1561 else:
1562 pad_ok = min(pad) >= 0
1563
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001564 return (
1565 # Check kernel sizes
1566 min(weight_shape[1:-1]) >= 1
1567 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001568 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001569 and (dilation is None or min(dilation) >= 1)
1570 )
1571
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """Check ERROR_IF condition: convolution output spatial dims differ
        from those computed from input/weight/stride/pad/dilation.

        TRANSPOSE_CONV2D uses (in - 1) * stride + pads + kernel per axis;
        other convs use the standard floor-divided formula.  Skipped when
        the parameters are themselves invalid.
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # transpose conv has no dilation attribute
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights hold kernel H/W from axis 0;
            # other convs from axis 1
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad is flat: (top, bottom, left, right, ...) so each
                    # spatial axis consumes two entries
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            # Compare only the spatial dims (batch and channel excluded)
            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1633
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        """Check ERROR_IF condition: convolution parameters leave a non-zero
        remainder, i.e. the dilated-kernel output formula does not divide
        evenly by the stride on some axis.

        Skipped when the parameters are themselves invalid.
        """
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights hold kernel H/W from axis 0;
            # other convs from axis 1
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad is flat; each spatial axis consumes two entries
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1682
    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        """Check ERROR_IF condition: ARGMAX output shape is not the input
        shape with the reduced axis removed.

        Dimensions are only compared when the output rank is already one
        less than the input rank (rank errors are reported by
        evArgmaxOutputRankMismatch).
        """
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            dimension_match = True
            # axis_shift realigns input indices with output indices once
            # the reduced axis has been skipped
            axis_shift = 0

            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1719
1720 @staticmethod
1721 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1722 error_name = ErrorIf.ArgmaxOutputRankMismatch
1723 param_reqs = {"rank": None, "dtype": None, "shape": None}
1724 error_result = False
1725 error_reason = (
1726 "Mismatch between output shape provided and expected output shape"
1727 )
1728
1729 if check:
1730 output_shape = kwargs["output_shape"]
1731 input_shape = kwargs["input_shape"]
1732 axis = kwargs["axis"]
1733 valid_params = axis >= 0 and axis < len(input_shape)
1734
1735 if valid_params and (len(input_shape) - 1) != len(output_shape):
1736 error_result = True
1737
1738 info_dict = {
1739 "error_name": error_name,
1740 "error_result": error_result,
1741 "error_reason": error_reason,
1742 "param_reqs": param_reqs,
1743 }
1744 return info_dict
1745
1746 @staticmethod
1747 def evKernelSmallerOne(check=False, **kwargs):
1748 error_name = ErrorIf.KernelSmallerOne
1749 param_reqs = {"rank": None, "dtype": None, "shape": None}
1750 error_result = False
1751 error_reason = "At least one kernel dimension is smaller than zero"
1752
1753 if check:
1754 kernel = kwargs["kernel"]
1755 if min(kernel) < 1:
1756 error_result = True
1757
1758 info_dict = {
1759 "error_name": error_name,
1760 "error_result": error_result,
1761 "error_reason": error_reason,
1762 "param_reqs": param_reqs,
1763 }
1764 return info_dict
1765
1766 @staticmethod
1767 def evStrideSmallerOne(check=False, **kwargs):
1768 error_name = ErrorIf.StrideSmallerOne
1769 param_reqs = {"rank": None, "dtype": None, "shape": None}
1770 error_result = False
1771 error_reason = "At least one stride dimension is smaller than zero"
1772
1773 if check:
1774 stride = kwargs["stride"]
1775 if min(stride) < 1:
1776 error_result = True
1777
1778 info_dict = {
1779 "error_name": error_name,
1780 "error_result": error_result,
1781 "error_reason": error_reason,
1782 "param_reqs": param_reqs,
1783 }
1784 return info_dict
1785
1786 @staticmethod
1787 def evDilationSmallerOne(check=False, **kwargs):
1788 error_result = check and min(kwargs["dilation"]) < 1
1789 return {
1790 "error_name": ErrorIf.DilationSmallerOne,
1791 "error_reason": "At least one dilation is smaller than one",
1792 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1793 "error_result": error_result,
1794 }
1795
1796 @staticmethod
1797 def evScaleTrue(check=False, **kwargs):
1798 error_name = ErrorIf.ScaleTrue
1799 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1800 error_result = False
1801 error_reason = "Scale set to true but input type is INT48"
1802
1803 if check:
1804 input_dtype = kwargs["input_dtype"]
1805 scale32 = kwargs["scale32"]
1806 if scale32 and input_dtype == DType.INT48:
1807 error_result = True
1808
1809 info_dict = {
1810 "error_name": error_name,
1811 "error_result": error_result,
1812 "error_reason": error_reason,
1813 "param_reqs": param_reqs,
1814 }
1815 return info_dict
1816
1817 @staticmethod
1818 def evScaleNotTrue(check=False, **kwargs):
1819 error_name = ErrorIf.ScaleNotTrue
1820 param_reqs = {"rank": None, "dtype": None, "shape": None}
1821 error_result = False
1822 error_reason = "Scale set to false but double round set to true"
1823
1824 if check:
1825 scale32 = kwargs["scale32"]
1826 double_round = kwargs["double_round"]
1827 if not scale32 and double_round:
1828 error_result = True
1829
1830 info_dict = {
1831 "error_name": error_name,
1832 "error_result": error_result,
1833 "error_reason": error_reason,
1834 "param_reqs": param_reqs,
1835 }
1836 return info_dict
1837
1838 @staticmethod
1839 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1840 error_name = ErrorIf.TensorSizeInputOutputMismatch
1841 param_reqs = {"rank": None, "dtype": None, "shape": None}
1842 error_result = False
1843 error_reason = "Input tensor size does not match output tensor size"
Jerry Ge264f7fa2023-04-21 22:49:57 +00001844 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001845
1846 if check:
1847 input_shape = kwargs["input_shape"]
1848 output_shape = kwargs["output_shape"]
Jerry Ge264f7fa2023-04-21 22:49:57 +00001849 shape_inferencing = False
1850 if -1 in output_shape and op["op"] == Op.RESHAPE:
1851 shape_inferencing = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001852 input_size = np.prod(input_shape)
1853 output_size = np.prod(output_shape)
Jerry Ge264f7fa2023-04-21 22:49:57 +00001854 if input_size != output_size and not shape_inferencing:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001855 error_result = True
1856
1857 info_dict = {
1858 "error_name": error_name,
1859 "error_result": error_result,
1860 "error_reason": error_reason,
1861 "param_reqs": param_reqs,
1862 }
1863 return info_dict
1864
1865 @staticmethod
1866 def evStartSmallerZero(check=False, **kwargs):
1867 error_name = ErrorIf.StartSmallerZero
1868 param_reqs = {"rank": None, "dtype": None, "shape": None}
1869 error_result = False
1870 error_reason = "Starting point smaller than zero"
1871
1872 if check:
1873 input_shape = kwargs["input_shape"]
1874 start = kwargs["start"]
1875 rank = len(input_shape)
1876 if len(start) == rank:
1877 for index in range(rank):
1878 if start[index] < 0:
1879 error_result = True
1880
1881 info_dict = {
1882 "error_name": error_name,
1883 "error_result": error_result,
1884 "error_reason": error_reason,
1885 "param_reqs": param_reqs,
1886 }
1887 return info_dict
1888
1889 @staticmethod
1890 def evSizeSmallerEqualZero(check=False, **kwargs):
1891 error_name = ErrorIf.SizeSmallerEqualZero
1892 param_reqs = {"rank": None, "dtype": None, "shape": None}
1893 error_result = False
1894 error_reason = "Size smaller than or equal to zero"
1895
1896 if check:
1897 input_shape = kwargs["input_shape"]
1898 size = kwargs["size"]
1899 rank = len(input_shape)
1900 if len(size) == rank:
1901 for index in range(rank):
1902 if size[index] <= 0:
1903 error_result = True
1904
1905 info_dict = {
1906 "error_name": error_name,
1907 "error_result": error_result,
1908 "error_reason": error_reason,
1909 "param_reqs": param_reqs,
1910 }
1911 return info_dict
1912
1913 @staticmethod
1914 def evStartSizeOutsideBounds(check=False, **kwargs):
1915 error_name = ErrorIf.StartSizeOutsideBounds
1916 param_reqs = {"rank": None, "dtype": None, "shape": None}
1917 error_result = False
1918 error_reason = "starting point plus size larger than input dimension"
1919
1920 if check:
1921 input_shape = kwargs["input_shape"]
1922 start = kwargs["start"]
1923 size = kwargs["size"]
1924 rank = len(input_shape)
1925 if len(start) == rank and len(size) == rank:
1926 for index in range(rank):
1927 if start[index] + size[index] > input_shape[index]:
1928 error_result = True
1929
1930 info_dict = {
1931 "error_name": error_name,
1932 "error_result": error_result,
1933 "error_reason": error_reason,
1934 "param_reqs": param_reqs,
1935 }
1936 return info_dict
1937
1938 @staticmethod
1939 def evSizeOutputShapeMismatch(check=False, **kwargs):
1940 error_name = ErrorIf.SizeOutputShapeMismatch
1941 param_reqs = {"rank": None, "dtype": None, "shape": None}
1942 error_result = False
1943 error_reason = "Size does not match output dimension"
1944
1945 if check:
1946 input_shape = kwargs["input_shape"]
1947 output_shape = kwargs["output_shape"]
1948 size = kwargs["size"]
Luke Huttona4e48ca2023-02-22 11:53:48 +00001949
1950 if len(input_shape) == len(output_shape):
1951 rank = len(input_shape)
1952 if len(size) == rank:
1953 for index in range(rank):
1954 if size[index] != output_shape[index]:
1955 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001956
1957 info_dict = {
1958 "error_name": error_name,
1959 "error_result": error_result,
1960 "error_reason": error_reason,
1961 "param_reqs": param_reqs,
1962 }
1963 return info_dict
1964
1965 @staticmethod
1966 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1967 error_name = ErrorIf.InputSizeStartLengthMismatch
1968 param_reqs = {"rank": None, "dtype": None, "shape": None}
1969 error_result = False
1970 error_reason = "rank of input not equal to length of start or size"
1971
1972 if check:
1973 input_shape = kwargs["input_shape"]
1974 start = kwargs["start"]
1975 size = kwargs["size"]
1976 rank = len(input_shape)
1977 if rank != len(start) or rank != len(size):
1978 error_result = True
1979
1980 info_dict = {
1981 "error_name": error_name,
1982 "error_result": error_result,
1983 "error_reason": error_reason,
1984 "param_reqs": param_reqs,
1985 }
1986 return info_dict
1987
1988 @staticmethod
1989 def evIndexOutsideBounds(check=False, **kwargs):
1990 error_name = ErrorIf.IndexOutsideBounds
1991 param_reqs = {"rank": None, "dtype": None, "shape": None}
1992 error_result = False
1993 error_reason = "Index outside of allowed bounds"
1994
1995 if check:
1996 input_shape = kwargs["input_shape"]
1997 perms = kwargs["perms"]
1998 rank = len(input_shape)
1999
2000 for index in perms:
2001 if index < 0 or index > rank:
2002 error_result = True
2003
2004 info_dict = {
2005 "error_name": error_name,
2006 "error_result": error_result,
2007 "error_reason": error_reason,
2008 "param_reqs": param_reqs,
2009 }
2010 return info_dict
2011
2012 @staticmethod
2013 def evIndexUsedTwice(check=False, **kwargs):
2014 error_name = ErrorIf.IndexUsedTwice
2015 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2016 error_result = False
2017 error_reason = "Index used multiple times"
2018
2019 if check:
2020 perms = kwargs["perms"]
2021
2022 unique_indices = []
2023 for index in perms:
2024 if index in unique_indices:
2025 error_result = True
2026 else:
2027 unique_indices.append(index)
2028
2029 info_dict = {
2030 "error_name": error_name,
2031 "error_result": error_result,
2032 "error_reason": error_reason,
2033 "param_reqs": param_reqs,
2034 }
2035 return info_dict
2036
2037 @staticmethod
2038 def evMaxSmallerMin(check=False, **kwargs):
2039 error_name = ErrorIf.MaxSmallerMin
2040 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2041 error_result = False
2042 error_reason = "Max value smaller than min value"
2043
2044 if check:
2045 max_val = kwargs["max_val"]
2046 min_val = kwargs["min_val"]
2047 if max_val < min_val:
2048 error_result = True
2049
2050 info_dict = {
2051 "error_name": error_name,
2052 "error_result": error_result,
2053 "error_reason": error_reason,
2054 "param_reqs": param_reqs,
2055 }
2056 return info_dict
2057
2058 @staticmethod
2059 def evConcatInputRankMismatch(check=False, **kwargs):
2060 error_name = ErrorIf.ConcatInputRankMismatch
2061 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2062 error_result = False
2063 error_reason = "Input ranks are not identical"
2064
2065 if check:
2066 inputs = kwargs["inputs"]
2067 input_shape = kwargs["input_shape"]
2068 for input in inputs:
2069 if len(input.shape) != len(input_shape):
2070 error_result = True
2071
2072 info_dict = {
2073 "error_name": error_name,
2074 "error_result": error_result,
2075 "error_reason": error_reason,
2076 "param_reqs": param_reqs,
2077 }
2078 return info_dict
2079
2080 @staticmethod
2081 def evConcatInputDimMismatch(check=False, **kwargs):
2082 error_name = ErrorIf.ConcatInputDimMismatch
2083 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2084 error_result = False
2085 error_reason = "Input dimensions differ on too many axes"
2086
2087 if check:
2088 inputs = kwargs["inputs"]
2089 input_shape = kwargs["input_shape"]
2090 axis = kwargs["axis"]
2091
2092 # Ensure rank is valid before checking dims.
2093 valid_rank = True
2094 for input in inputs:
2095 if len(input.shape) != len(input_shape):
2096 valid_rank = False
2097
2098 if valid_rank:
2099 for input in inputs:
2100 for i, dim in enumerate(input.shape):
2101 if dim != input_shape[i] and axis != i:
2102 error_result = True
2103
2104 info_dict = {
2105 "error_name": error_name,
2106 "error_result": error_result,
2107 "error_reason": error_reason,
2108 "param_reqs": param_reqs,
2109 }
2110 return info_dict
2111
2112 @staticmethod
2113 def evConcatShapeSumMismatch(check=False, **kwargs):
2114 error_name = ErrorIf.ConcatShapeSumMismatch
2115 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2116 error_result = False
2117 error_reason = "Sum of dimensions on axis not equal to output dimension"
2118
2119 if check:
2120 inputs = kwargs["inputs"]
2121 input_shape = kwargs["input_shape"]
2122 output_shape = kwargs["output_shape"]
2123 axis = kwargs["axis"]
2124
2125 # Ensure rank is valid before checking dims.
2126 valid_params = True
2127 for input in inputs:
2128 if len(input.shape) != len(input_shape):
2129 valid_params = False
2130 if axis < 0 or axis > len(input_shape):
2131 valid_params = False
2132
2133 if valid_params:
2134 axis_dim_sum = 0
2135 for input in inputs:
2136 axis_dim_sum += input.shape[axis]
2137
2138 if axis_dim_sum != output_shape[axis]:
2139 error_result = True
2140
2141 info_dict = {
2142 "error_name": error_name,
2143 "error_result": error_result,
2144 "error_reason": error_reason,
2145 "param_reqs": param_reqs,
2146 }
2147 return info_dict
2148
2149 @staticmethod
2150 def evInputListThenGraphMismatch(check=False, **kwargs):
2151 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2152 param_reqs = {"rank": None, "dtype": None, "shape": None}
2153 error_result = False
2154 error_reason = "Input list shape does not match then-graph shape"
2155
2156 if check:
2157 a = kwargs["a"]
2158 b = kwargs["b"]
2159 basicBlocks = kwargs["basicBlocks"]
2160 then_block = basicBlocks[1]
2161 then_inputs = then_block.inputs
2162 then_tens = then_block.tensors
2163 if (a.shape != then_tens[then_inputs[0]].shape) or (
2164 b.shape != then_tens[then_inputs[1]].shape
2165 ):
2166 error_result = True
2167
2168 info_dict = {
2169 "error_name": error_name,
2170 "error_result": error_result,
2171 "error_reason": error_reason,
2172 "param_reqs": param_reqs,
2173 }
2174 return info_dict
2175
2176 @staticmethod
2177 def evInputListElseGraphMismatch(check=False, **kwargs):
2178 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2179 param_reqs = {"rank": None, "dtype": None, "shape": None}
2180 error_result = False
2181 error_reason = "Input list shape does not match else-graph shape"
2182
2183 if check:
2184 a = kwargs["a"]
2185 b = kwargs["b"]
2186 basicBlocks = kwargs["basicBlocks"]
2187 else_block = basicBlocks[2]
2188 else_inputs = else_block.inputs
2189 else_tens = else_block.tensors
2190 if (a.shape != else_tens[else_inputs[0]].shape) or (
2191 b.shape != else_tens[else_inputs[1]].shape
2192 ):
2193 error_result = True
2194
2195 info_dict = {
2196 "error_name": error_name,
2197 "error_result": error_result,
2198 "error_reason": error_reason,
2199 "param_reqs": param_reqs,
2200 }
2201 return info_dict
2202
2203 @staticmethod
2204 def evOutputListThenGraphMismatch(check=False, **kwargs):
2205 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2206 param_reqs = {"rank": None, "dtype": None, "shape": None}
2207 error_result = False
2208 error_reason = "Output list shape does not match then-graph shape"
2209
2210 if check:
2211 basicBlocks = kwargs["basicBlocks"]
2212 cond_block = basicBlocks[0]
2213 cond_outputs = cond_block.outputs
2214 cond_tens = cond_block.tensors
2215 then_block = basicBlocks[1]
2216 then_outputs = then_block.outputs
2217 then_tens = then_block.tensors
2218 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2219 error_result = True
2220
2221 info_dict = {
2222 "error_name": error_name,
2223 "error_result": error_result,
2224 "error_reason": error_reason,
2225 "param_reqs": param_reqs,
2226 }
2227 return info_dict
2228
2229 @staticmethod
2230 def evOutputListElseGraphMismatch(check=False, **kwargs):
2231 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2232 param_reqs = {"rank": None, "dtype": None, "shape": None}
2233 error_result = False
2234 error_reason = "Output list shape does not match else-graph shape"
2235
2236 if check:
2237 basicBlocks = kwargs["basicBlocks"]
2238 cond_block = basicBlocks[0]
2239 cond_outputs = cond_block.outputs
2240 cond_tens = cond_block.tensors
2241 else_block = basicBlocks[2]
2242 else_outputs = else_block.outputs
2243 else_tens = else_block.tensors
2244 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2245 error_result = True
2246
2247 info_dict = {
2248 "error_name": error_name,
2249 "error_result": error_result,
2250 "error_reason": error_reason,
2251 "param_reqs": param_reqs,
2252 }
2253 return info_dict
2254
2255 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002256 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2257 error_name = ErrorIf.CondIfCondNotMatchingBool
2258 param_reqs = {"rank": None, "dtype": None, "shape": None}
2259 error_result = False
2260 error_reason = "Conditional tensor does not match bool type"
2261
2262 if check:
2263 cond = kwargs["cond"]
2264 if cond.dtype != DType.BOOL:
2265 error_result = True
2266
2267 info_dict = {
2268 "error_name": error_name,
2269 "error_result": error_result,
2270 "error_reason": error_reason,
2271 "param_reqs": param_reqs,
2272 }
2273 return info_dict
2274
2275 @staticmethod
2276 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2277 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2278 param_reqs = {"rank": None, "dtype": None, "shape": None}
2279 error_result = False
2280 error_reason = "Conditional tensor is not equal to a size of one"
2281
2282 if check:
2283 cond = kwargs["cond"]
2284 # Size of 1 is equivalent to rank 0
2285 if len(cond.shape) != 0:
2286 error_result = True
2287
2288 info_dict = {
2289 "error_name": error_name,
2290 "error_result": error_result,
2291 "error_reason": error_reason,
2292 "param_reqs": param_reqs,
2293 }
2294 return info_dict
2295
2296 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002297 def evInputListOutputListMismatch(check=False, **kwargs):
2298 error_name = ErrorIf.InputListOutputListMismatch
2299 param_reqs = {"rank": None, "dtype": None, "shape": None}
2300 error_result = False
2301 error_reason = "Input list does not match output list"
2302
2303 if check:
2304 basicBlocks = kwargs["basicBlocks"]
2305 while_block = basicBlocks[0]
2306 while_inputs = while_block.inputs
2307 while_outputs = while_block.outputs
2308 while_tens = while_block.tensors
2309 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2310 error_result = True
2311
2312 info_dict = {
2313 "error_name": error_name,
2314 "error_result": error_result,
2315 "error_reason": error_reason,
2316 "param_reqs": param_reqs,
2317 }
2318 return info_dict
2319
2320 @staticmethod
2321 def evInputListCondGraphMismatch(check=False, **kwargs):
2322 error_name = ErrorIf.InputListCondGraphMismatch
2323 param_reqs = {"rank": None, "dtype": None, "shape": None}
2324 error_result = False
2325 error_reason = "Input list does not match cond graph"
2326
2327 if check:
2328 basicBlocks = kwargs["basicBlocks"]
2329 while_block = basicBlocks[0]
2330 while_inputs = while_block.inputs
2331 while_tens = while_block.tensors
2332 cond_block = basicBlocks[1]
2333 cond_inputs = cond_block.inputs
2334 cond_tens = cond_block.tensors
2335 if (
2336 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2337 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2338 error_result = True
2339
2340 info_dict = {
2341 "error_name": error_name,
2342 "error_result": error_result,
2343 "error_reason": error_reason,
2344 "param_reqs": param_reqs,
2345 }
2346 return info_dict
2347
2348 @staticmethod
2349 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2350 error_name = ErrorIf.InputListBodyGraphInputMismatch
2351 param_reqs = {"rank": None, "dtype": None, "shape": None}
2352 error_result = False
2353 error_reason = "Input list does not match body graph input"
2354
2355 if check:
2356 basicBlocks = kwargs["basicBlocks"]
2357 while_block = basicBlocks[0]
2358 while_inputs = while_block.inputs
2359 while_tens = while_block.tensors
2360 body_block = basicBlocks[2]
2361 body_outputs = body_block.inputs
2362 body_tens = body_block.tensors
2363 if (
2364 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2365 ) or (
2366 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2367 ):
2368 error_result = True
2369
2370 info_dict = {
2371 "error_name": error_name,
2372 "error_result": error_result,
2373 "error_reason": error_reason,
2374 "param_reqs": param_reqs,
2375 }
2376 return info_dict
2377
2378 @staticmethod
2379 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2380 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2381 param_reqs = {"rank": None, "dtype": None, "shape": None}
2382 error_result = False
2383 error_reason = "Input list does not match body graph output"
2384
2385 if check:
2386 basicBlocks = kwargs["basicBlocks"]
2387 while_block = basicBlocks[0]
2388 while_inputs = while_block.inputs
2389 while_tens = while_block.tensors
2390 body_block = basicBlocks[2]
2391 body_outputs = body_block.outputs
2392 body_tens = body_block.tensors
2393 if (
2394 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2395 ) or (
2396 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2397 ):
2398 error_result = True
2399 info_dict = {
2400 "error_name": error_name,
2401 "error_result": error_result,
2402 "error_reason": error_reason,
2403 "param_reqs": param_reqs,
2404 }
2405 return info_dict
2406
2407 @staticmethod
2408 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2409 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2410 param_reqs = {"rank": None, "dtype": None, "shape": None}
2411 error_result = False
2412 error_reason = "Cond graph output is not a match list of booleans"
2413
2414 if check:
2415 basicBlocks = kwargs["basicBlocks"]
2416 cond_block = basicBlocks[1]
2417 cond_outputs = cond_block.outputs
2418 cond_tens = cond_block.tensors
2419 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2420 error_result = True
2421
2422 info_dict = {
2423 "error_name": error_name,
2424 "error_result": error_result,
2425 "error_reason": error_reason,
2426 "param_reqs": param_reqs,
2427 }
2428 return info_dict
2429
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002430 @staticmethod
2431 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2432 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2433 param_reqs = {"rank": None, "dtype": None, "shape": None}
2434 error_result = False
2435 error_reason = "Cond graph output is not a shape of size one"
2436
2437 if check:
2438 basicBlocks = kwargs["basicBlocks"]
2439 cond_block = basicBlocks[1]
2440 cond_outputs = cond_block.outputs
2441 cond_tens = cond_block.tensors
2442 # Size of 1 is equivalent to rank 0
2443 if len(cond_tens[cond_outputs[0]].shape) != 0:
2444 error_result = True
2445
2446 info_dict = {
2447 "error_name": error_name,
2448 "error_result": error_result,
2449 "error_reason": error_reason,
2450 "param_reqs": param_reqs,
2451 }
2452 return info_dict
2453
Luke Hutton261b7b62023-01-10 14:50:31 +00002454 @staticmethod
2455 def evKernelNotPowerOfTwo(check=False, **kwargs):
2456 error_name = ErrorIf.KernelNotPowerOfTwo
2457 param_reqs = {"rank": None, "dtype": None, "shape": None}
2458 error_result = False
2459 error_reason = "kernel height and/or width not a power of two"
2460
2461 def is_power_of_two(x):
2462 return math.log(x, 2).is_integer()
2463
2464 if check:
2465 shape = kwargs["input_shape"]
2466 if len(shape) == 3:
2467 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2468 error_result = not valid_kernel
2469
2470 info_dict = {
2471 "error_name": error_name,
2472 "error_result": error_result,
2473 "error_reason": error_reason,
2474 "param_reqs": param_reqs,
2475 }
2476 return info_dict
2477
Luke Hutton57287132023-02-06 14:54:18 +00002478 @staticmethod
2479 def evFFTInputShapeMismatch(check=False, **kwargs):
2480 error_name = ErrorIf.FFTInputShapeMismatch
2481 param_reqs = {"rank": None, "dtype": None, "shape": None}
2482 error_result = False
2483 error_reason = "Mismatch between real and imaginary input shapes"
2484
2485 if check:
2486 input1 = kwargs["input1"]
2487 input2 = kwargs["input2"]
2488
2489 if input1.shape != input2.shape:
2490 error_result = True
2491
2492 info_dict = {
2493 "error_name": error_name,
2494 "error_result": error_result,
2495 "error_reason": error_reason,
2496 "param_reqs": param_reqs,
2497 }
2498 return info_dict
2499
2500 @staticmethod
2501 def evFFTOutputShapeMismatch(check=False, **kwargs):
2502 error_name = ErrorIf.FFTOutputShapeMismatch
2503 param_reqs = {"rank": None, "dtype": None, "shape": None}
2504 error_result = False
2505 error_reason = (
2506 "Mismatch between provided and expected output kernel (H, W) shape"
2507 )
2508
2509 if check:
2510 op = kwargs["op"]
2511 input_shape = kwargs["input_shape"]
2512
2513 if len(input_shape) == 3:
2514 output_shapes = kwargs["output_shape"]
2515
2516 # Ignoring batch size (N) from input shape
2517 expected_shape = input_shape[1:]
2518 if op["op"] == Op.RFFT2D:
2519 expected_shape[1] = expected_shape[1] // 2 + 1
2520
2521 # Ignoring batch size (N) from output shapes
2522 output_shape_0 = output_shapes[0][1:]
2523 output_shape_1 = output_shapes[1][1:]
2524 # Ensure sure the kernel sizes (H, W) of both outputs match the expected
2525 if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
2526 error_result = True
2527
2528 info_dict = {
2529 "error_name": error_name,
2530 "error_result": error_result,
2531 "error_reason": error_reason,
2532 "param_reqs": param_reqs,
2533 }
2534 return info_dict
2535
Jerry Ge264f7fa2023-04-21 22:49:57 +00002536 @staticmethod
Jerry Ge135c9552023-05-23 20:59:32 +00002537 def calculateBroadcastShape(input_shape_a, input_shape_b):
2538 if input_shape_a is not None and input_shape_b is not None:
2539 calculated_shape = input_shape_a.copy()
2540 for idx in range(len(calculated_shape)):
2541 if calculated_shape[idx] == 1:
2542 calculated_shape[idx] = input_shape_b[idx]
2543 elif (
2544 input_shape_b[idx] != 1
2545 and input_shape_b[idx] != calculated_shape[idx]
2546 ):
2547 return None
2548 return calculated_shape
2549 else:
2550 return None
2551
2552 @staticmethod
2553 def evBroadcastShapesMismatch(check=False, **kwargs):
2554 error_name = ErrorIf.BroadcastShapesMismatch
2555 param_reqs = {"rank": None, "dtype": None, "shape": None}
2556 error_result = False
2557 error_reason = "Broadcast shape calculating failed"
2558
2559 if check:
2560 input_shape_a = kwargs["input1"].shape
2561 input_shape_b = kwargs["input2"].shape
2562 input_shape_c = (
2563 kwargs["input3"].shape if "input3" in kwargs else input_shape_b
2564 )
2565
2566 if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
2567 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
2568 input_shape_c,
2569 TosaErrorValidator.calculateBroadcastShape(
2570 input_shape_a, input_shape_b
2571 ),
2572 )
2573 error_result = calculated_shape is None
2574
2575 info_dict = {
2576 "error_name": error_name,
2577 "error_result": error_result,
2578 "error_reason": error_reason,
2579 "param_reqs": param_reqs,
2580 }
2581 return info_dict
2582
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002583
2584class TosaInvalidValidator:
2585 @staticmethod
2586 def ivWrongDataTypeOrModeResize(**kwargs):
2587 input_dtype = kwargs["input_dtype"]
2588 args = kwargs["args"]
2589 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002590 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002591
2592 if mode == ResizeMode.BILINEAR:
2593 # Invalid output data type / Invalid input datatype
2594 return (
2595 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002596 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002597 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002598 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002599 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002600 )
2601 elif mode == ResizeMode.NEAREST:
2602 # Invalid output data type / Invalid input datatype
2603 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002604 input_dtype
2605 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002606 )
2607 else:
2608 # Invalid resize mode
2609 return True
2610
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when an op's computed output H/W would be invalid.

        Covers pool2d ops (output dim < 1), transpose_conv2d (declared output
        shape does not match the computed one) and conv2d/conv3d/depthwise
        (any computed output dim < 1).  Asserts on any other op name.
        """
        opName = kwargs["opName"]

        # First shape in the list is the input tensor; for conv-style ops the
        # second is the filter (see below).
        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        if isinstance(args, dict):
            args_dict = args
        else:
            # Create args_dict from list elements
            # TODO - Remove this once all NWHC operators agFunctions have been
            # converted to args_dict output

            # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
            stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
            args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
            # Alias different info for each op
            # NOTE(review): all three keys deliberately alias the SAME args
            # slot - which one is meaningful depends on the op branch below.
            args_dict["kernel"] = args[pad_idx + 1]
            args_dict["out_shape"] = args[pad_idx + 1]
            args_dict["dilation"] = args[pad_idx + 1]

        # Common info for all ops
        strides = args_dict["stride"]
        padding = args_dict["pad"]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args_dict["kernel"]
            # Output size formula assumes NHWC layout: dims 1/2 are H/W.
            # Padding is ordered [top, bottom, left, right].
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args_dict["out_shape"]
            filter_shape = inputShapes[1]
            # Filter layout here has spatial dims in the middle (e.g. OHWI).
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension."""
                return (in_size - 1) * stride + kernel_size + in_pad + out_pad

            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                padding[0],
                padding[1],
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                padding[2],
                padding[3],
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False
            # output shape does not match the expected shape
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args_dict["dilation"]
            filter_shape = inputShapes[1]
            # depthwise filters keep spatial dims first; other conv filters
            # have them in the middle.
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            # Standard dilated-convolution output size per spatial dimension.
            for i in range(len(kernel_shape)):
                pad_offset = i * 2
                dim = (
                    input_shape[i + 1]
                    - 1
                    + padding[pad_offset]
                    + padding[pad_offset + 1]
                    - (kernel_shape[i] - 1) * dilations[i]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"
2705
2706 @staticmethod
2707 def ivNonPositiveOutputShape(**kwargs):
2708 args = kwargs["args"]
Jeremy Johnson95a67102024-01-10 14:16:39 +00002709 output_shape = args["out_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002710 if output_shape[1] <= 0 or output_shape[2] <= 0:
2711 # Negative output shape
2712 return True
2713 return False