# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import logging
import math

import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class ErrorIf(object):
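    """Names of the ERROR_IF conditions that the test generator can exercise."""
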
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"
    WrongAccumulatorType = "WrongAccumulatorType"


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
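        # Corrupt the resize arguments so the requested ERROR_IF condition
        # fires. scale is laid out as [scale_y_n, scale_y_d, scale_x_n,
        # scale_x_d]; offset and border each hold a [y, x] pair.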
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
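        # Return (stride, pad, kernel) with the relevant tuple corrupted for
        # error_name, or (None, None, None) when no corruption is applied.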
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
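        # True when output_dtype is not a legal RESCALE result type for the
        # given input_dtype.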
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
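        # Return (start, size) corrupted to trigger error_name; unhandled
        # error names return the original values.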
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])

            # Removing a dimension from a 1-D tensor would leave an empty
            # tensor, so always extend instead.
            if len(start) == 1 or len(size) == 1:
                remove = False

            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
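        # Return the output dtypes that are invalid for a CAST from
        # input_dtype.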
        if input_dtype in [DType.BOOL]:
            outputDType = [
                DType.BOOL,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
        elif input_dtype in [DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            outputDType = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            ]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType


class TosaErrorValidator:
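    """Validators for the ERROR_IF conditions.

    Each evXxx method returns an info_dict describing one ERROR_IF
    condition; when called with check=True it also inspects the supplied
    kwargs and sets "error_result" to indicate whether that condition is
    present.
    """
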
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        if validator_fcns is None:
            # Nothing to do
            return True
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                logger.error(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                logger.error(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        logger.error(f"  {k} = {v}")

        return overall_result

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into the required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
                    or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [
                        DType.INT8,
                        DType.INT16,
                        DType.FP16,
                        DType.BF16,
                        DType.FP32,
                        DType.FP8E4M3,
                        DType.FP8E5M2,
                    ]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
                        and output_dtype
                        not in [
                            DType.FP16,
                            DType.BF16,
                            DType.FP32,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP8E4M3
                    and output_dtype != DType.FP16
                    or input_dtype == DType.FP8E5M2
                    and output_dtype != DType.FP16
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
                if output_dtype != DType.SHAPE:
                    error_result = True

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = (
                    output.shape
                )  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
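        # A parameter set is valid when both scale numerators (scale[0] and
        # scale[2]) are in (0, 1 << 11], each denominator is less than 16x
        # its numerator, offsets lie in [-scale_n, 16 * scale_n) and borders
        # lie in [-16 * scale_n, scale_n).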
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
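                # Expected size per dimension follows the RESIZE definition:
                # out = floor(((in - 1) * scale_n - offset + border) / scale_d) + 1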
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            op = kwargs["op"]
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
                output_shape = kwargs["result_tensors"][0].shape
                if input1_shape != output_shape:
                    error_result = True

            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
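            # PAD carries per-dimension [before, after] pairs; pooling ops
            # carry a flat [top, bottom, left, right] list.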
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
1522 return (
1523 min(kernel) >= 1
1524 and min(stride) >= 1
1525 and min(pad) >= 0
1526 and not (
1527 pad[0] >= kernel[0]
1528 or pad[1] >= kernel[0]
1529 or pad[2] >= kernel[1]
1530 or pad[3] >= kernel[1]
1531 )
1532 )
1533
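# checkPoolingParams is a guard rather than an error check: the shape
# validators below only report a mismatch when kernel, stride and pad are
# individually legal, so cases already covered by other ERROR_IF tests are
# not double-counted. Sketch (hypothetical values):
#
#     assert TosaErrorValidator.checkPoolingParams([2, 2], [1, 1], [0, 0, 0, 0])
#     assert not TosaErrorValidator.checkPoolingParams([2, 2], [0, 1], [0, 0, 0, 0])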
1534 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001535 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1536 error_name = ErrorIf.PoolingOutputShapeMismatch
1537 param_reqs = {"rank": None, "dtype": None, "shape": None}
1538 error_result = False
1539 error_reason = (
1540 "Mismatch between output shape provided and expected output shape"
1541 )
1542
1543 if check:
1544 pad = kwargs["pad"]
1545 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1546
1547 kernel = kwargs["kernel"]
1548 kernel_y, kernel_x = kernel[0], kernel[1]
1549
1550 input_shape = kwargs["input_shape"]
1551 IH, IW = input_shape[1], input_shape[2]
1552
1553 output_shape = kwargs["output_shape"]
1554 OH, OW = output_shape[1], output_shape[2]
1555
1556 stride = kwargs["stride"]
1557 stride_y, stride_x = stride[0], stride[1]
1558
1559 # calculate correct height, width dimensions
1560 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001561 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1562 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001563
1564 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001565 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001566
1567 if params_valid and (OH != y_correct or OW != x_correct):
1568 error_result = True
1569
1570 info_dict = {
1571 "error_name": error_name,
1572 "error_result": error_result,
1573 "error_reason": error_reason,
1574 "param_reqs": param_reqs,
1575 }
1576 return info_dict
1577
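# Worked example (hypothetical shapes) of the formula above,
# OH = (IH + pad_top + pad_bottom - kernel_y) // stride_y + 1:
# IH=32, pads=0, kernel_y=2, stride_y=2 gives OH = (32 - 2) // 2 + 1 = 16,
# so a provided output_shape of [1, 15, 16, 8] would set error_result.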
1578 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001579 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1580 error_name = ErrorIf.PoolingOutputShapeNonInteger
1581 param_reqs = {"rank": None, "dtype": None, "shape": None}
1582 error_result = False
1583 error_reason = "Parameters do not yield exact integer output dimensions"
1584
1585 if check:
1586 pad = kwargs["pad"]
1587 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1588
1589 kernel = kwargs["kernel"]
1590 kernel_y, kernel_x = kernel[0], kernel[1]
1591
1592 input_shape = kwargs["input_shape"]
1593 IH, IW = input_shape[1], input_shape[2]
1594
1595 stride = kwargs["stride"]
1596 stride_y, stride_x = stride[0], stride[1]
1597
1598 # calculate remainder of height, width dimensions
1599 if stride_x != 0 and stride_y != 0:
1600 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1601 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1602
1603 # ensure parameters are valid
1604 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1605 if params_valid and (y_remainder != 0 or x_remainder != 0):
1606 error_result = True
1607
1608 info_dict = {
1609 "error_name": error_name,
1610 "error_result": error_result,
1611 "error_reason": error_reason,
1612 "param_reqs": param_reqs,
1613 }
1614 return info_dict
1615
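# Minimal sketch (hypothetical values): the output is non-integer whenever
# (IH + pads - kernel) is not a multiple of the stride, e.g.
# IH=31, kernel_y=2, stride_y=4 leaves (31 - 2) % 4 == 1, so the requested
# pooling cannot produce exact integer output dimensions.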
1616 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001617 def checkConvParams(op, weight_shape, stride, pad, dilation):
1618 if op == Op.TRANSPOSE_CONV2D:
1619 pad_ok = (
1620 pad[0] > -weight_shape[1]
1621 and pad[1] > -weight_shape[1]
1622 and pad[2] > -weight_shape[2]
1623 and pad[3] > -weight_shape[2]
1624 )
1625 else:
1626 pad_ok = min(pad) >= 0
1627
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001628 return (
1629 # Check kernel sizes
1630 min(weight_shape[1:-1]) >= 1
1631 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001632 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001633 and (dilation is None or min(dilation) >= 1)
1634 )
1635
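# Like checkPoolingParams, this guard keeps the conv shape validators from
# firing on parameter sets that other ERROR_IF checks already reject. Note
# the TRANSPOSE_CONV2D branch, where pads may be negative as long as they
# stay above -kernel. Sketch (hypothetical OHWI weight shape):
#
#     assert TosaErrorValidator.checkConvParams(
#         Op.TRANSPOSE_CONV2D, [8, 2, 2, 4], [1, 1], [-1, 0, 0, 0], None
#     )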
1636 @staticmethod
1637 def evConvOutputShapeMismatch(check=False, **kwargs):
1638 error_name = ErrorIf.ConvOutputShapeMismatch
1639 param_reqs = {"rank": None, "dtype": None, "shape": None}
1640 error_result = False
1641 error_reason = (
1642 "Mismatch between output shape provided and expected output shape"
1643 )
1644
1645 if check:
1646 op = kwargs["op"]
1647 pad = kwargs["pad"]
1648 weight_shape = kwargs["weight_shape"]
1649 input_shape = kwargs["input_shape"]
1650 output_shape = kwargs["output_shape"]
1651 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1652 stride = kwargs["stride"]
1653
1654 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1655
1656 # calculate correct dimensions
1657 dims_correct = []
1658 if min(stride) > 0:
1659 for index in range(len(stride)):
1660 pad_offset = index * 2
1661 if op["op"] == Op.TRANSPOSE_CONV2D:
1662 dims_correct.append(
1663 (input_shape[index + 1] - 1) * stride[index]
Eric Kunzec1a97832022-07-01 16:56:09 -07001664 + pad[pad_offset]
1665 + pad[pad_offset + 1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001666 + weight_shape[index + kernel_offset]
1667 )
1668 else:
1669 dims_correct.append(
1670 (
1671 input_shape[index + 1]
1672 - 1
1673 + pad[pad_offset]
1674 + pad[pad_offset + 1]
1675 - (weight_shape[index + kernel_offset] - 1)
1676 * dilation[index]
1677 )
1678 // stride[index]
1679 + 1
1680 )
1681
1682 # ensure parameters are valid
1683 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001684 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001685 )
1686
1687 if params_valid and output_shape[1:-1] != dims_correct:
1688 error_result = True
1689
1690 info_dict = {
1691 "error_name": error_name,
1692 "error_result": error_result,
1693 "error_reason": error_reason,
1694 "param_reqs": param_reqs,
1695 }
1696 return info_dict
1697
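# Worked example (hypothetical shapes) of the non-transpose branch above,
# OH = (IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) // stride_y + 1:
# IH=16, pads=1+1, KH=3, dilation=1, stride=1 gives
# OH = (16 - 1 + 2 - 2) // 1 + 1 = 16, so a provided output H of 14 is flagged.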
1698 @staticmethod
1699 def evConvOutputShapeNonInteger(check=False, **kwargs):
1700 error_name = ErrorIf.ConvOutputShapeNonInteger
1701 param_reqs = {"rank": None, "dtype": None, "shape": None}
1702 error_result = False
1703 error_reason = "Parameters do not yield exact integer output dimensions"
1704
1705 if check:
1706 op = kwargs["op"]
1707 pad = kwargs["pad"]
1708 weight_shape = kwargs["weight_shape"]
1709 input_shape = kwargs["input_shape"]
1710 dilation = kwargs["dilation"]
1711 stride = kwargs["stride"]
1712
1713 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1714
1715 # calculate correct height, width dimensions
1716 remainders = []
1717 if min(stride) > 0:
1718 for index in range(len(stride)):
1719 pad_offset = index * 2
1720 remainders.append(
1721 (
1722 input_shape[index + 1]
1723 - 1
1724 + pad[pad_offset]
1725 + pad[pad_offset + 1]
1726 - (weight_shape[index + kernel_offset] - 1)
1727 * dilation[index]
1728 )
1729 % stride[index]
1730 )
1731
1732 # ensure parameters are valid
1733 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001734 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001735 )
1736 if params_valid and max(remainders) > 0:
1737 error_result = True
1738
1739 info_dict = {
1740 "error_name": error_name,
1741 "error_result": error_result,
1742 "error_reason": error_reason,
1743 "param_reqs": param_reqs,
1744 }
1745 return info_dict
1746
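# The same numerator as the mismatch check, taken modulo the stride
# (hypothetical values): IH=16, pads=0, KH=3, dilation=1, stride=2 leaves
# (16 - 1 - 2) % 2 == 1, so no exact integer output dimension exists.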
1747 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001748 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1749 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1750 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1751 error_result = False
1752 error_reason = (
1753 "Mismatch between output shape provided and expected output shape"
1754 )
1755
1756 if check:
1757 output_shape = kwargs["output_shape"]
1758 input_shape = kwargs["input_shape"]
1759 axis = kwargs["axis"]
1760
1761 dimension_match = True
1762 axis_shift = 0
1763
1764 # Check that rank is correct before trying to check dimensions
1765 if (len(input_shape) - 1) == len(output_shape):
1766 for i in range(len(input_shape)):
1767 if i == axis:
1768 axis_shift = 1
1769 continue
1770 if input_shape[i] != output_shape[i - axis_shift]:
1771 dimension_match = False
1772
1773 if not dimension_match:
1774 error_result = True
1775
1776 info_dict = {
1777 "error_name": error_name,
1778 "error_result": error_result,
1779 "error_reason": error_reason,
1780 "param_reqs": param_reqs,
1781 }
1782 return info_dict
1783
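# Sketch (hypothetical shapes): ARGMAX removes the reduced axis, so an input
# of [2, 3, 4] with axis=1 must produce output [2, 4]; anything else of the
# correct rank (for example [2, 3]) sets error_result.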
1784 @staticmethod
1785 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1786 error_name = ErrorIf.ArgmaxOutputRankMismatch
1787 param_reqs = {"rank": None, "dtype": None, "shape": None}
1788 error_result = False
1789 error_reason = (
1790 "Mismatch between output shape provided and expected output shape"
1791 )
1792
1793 if check:
1794 output_shape = kwargs["output_shape"]
1795 input_shape = kwargs["input_shape"]
1796 axis = kwargs["axis"]
1797 valid_params = axis >= 0 and axis < len(input_shape)
1798
1799 if valid_params and (len(input_shape) - 1) != len(output_shape):
1800 error_result = True
1801
1802 info_dict = {
1803 "error_name": error_name,
1804 "error_result": error_result,
1805 "error_reason": error_reason,
1806 "param_reqs": param_reqs,
1807 }
1808 return info_dict
1809
1810 @staticmethod
1811 def evKernelSmallerOne(check=False, **kwargs):
1812 error_name = ErrorIf.KernelSmallerOne
1813 param_reqs = {"rank": None, "dtype": None, "shape": None}
1814 error_result = False
1815 error_reason = "At least one kernel dimension is smaller than one"
1816
1817 if check:
1818 kernel = kwargs["kernel"]
1819 if min(kernel) < 1:
1820 error_result = True
1821
1822 info_dict = {
1823 "error_name": error_name,
1824 "error_result": error_result,
1825 "error_reason": error_reason,
1826 "param_reqs": param_reqs,
1827 }
1828 return info_dict
1829
1830 @staticmethod
1831 def evStrideSmallerOne(check=False, **kwargs):
1832 error_name = ErrorIf.StrideSmallerOne
1833 param_reqs = {"rank": None, "dtype": None, "shape": None}
1834 error_result = False
1835 error_reason = "At least one stride dimension is smaller than one"
1836
1837 if check:
1838 stride = kwargs["stride"]
1839 if min(stride) < 1:
1840 error_result = True
1841
1842 info_dict = {
1843 "error_name": error_name,
1844 "error_result": error_result,
1845 "error_reason": error_reason,
1846 "param_reqs": param_reqs,
1847 }
1848 return info_dict
1849
1850 @staticmethod
1851 def evDilationSmallerOne(check=False, **kwargs):
1852 error_result = check and min(kwargs["dilation"]) < 1
1853 return {
1854 "error_name": ErrorIf.DilationSmallerOne,
1855 "error_reason": "At least one dilation is smaller than one",
1856 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1857 "error_result": error_result,
1858 }
1859
1860 @staticmethod
1861 def evScaleTrue(check=False, **kwargs):
1862 error_name = ErrorIf.ScaleTrue
1863 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1864 error_result = False
1865 error_reason = "Scale set to true but input type is INT48"
1866
1867 if check:
1868 input_dtype = kwargs["input_dtype"]
1869 scale32 = kwargs["scale32"]
1870 if scale32 and input_dtype == DType.INT48:
1871 error_result = True
1872
1873 info_dict = {
1874 "error_name": error_name,
1875 "error_result": error_result,
1876 "error_reason": error_reason,
1877 "param_reqs": param_reqs,
1878 }
1879 return info_dict
1880
1881 @staticmethod
1882 def evScaleNotTrue(check=False, **kwargs):
1883 error_name = ErrorIf.ScaleNotTrue
1884 param_reqs = {"rank": None, "dtype": None, "shape": None}
1885 error_result = False
1886 error_reason = "Scale set to false but double round set to true"
1887
1888 if check:
1889 scale32 = kwargs["scale32"]
1890 double_round = kwargs["double_round"]
1891 if not scale32 and double_round:
1892 error_result = True
1893
1894 info_dict = {
1895 "error_name": error_name,
1896 "error_result": error_result,
1897 "error_reason": error_reason,
1898 "param_reqs": param_reqs,
1899 }
1900 return info_dict
1901
1902 @staticmethod
1903 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1904 error_name = ErrorIf.TensorSizeInputOutputMismatch
1905 param_reqs = {"rank": None, "dtype": None, "shape": None}
1906 error_result = False
1907 error_reason = "Input tensor size does not match output tensor size"
1909
1910 if check:
op = kwargs["op"]
1911 input_shape = kwargs["input_shape"]
1912 output_shape = kwargs["output_shape"]
Jerry Ge264f7fa2023-04-21 22:49:57 +00001913 shape_inferencing = False
1914 if -1 in output_shape and op["op"] == Op.RESHAPE:
1915 shape_inferencing = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001916 input_size = np.prod(input_shape)
1917 output_size = np.prod(output_shape)
Jerry Ge264f7fa2023-04-21 22:49:57 +00001918 if input_size != output_size and not shape_inferencing:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001919 error_result = True
1920
1921 info_dict = {
1922 "error_name": error_name,
1923 "error_result": error_result,
1924 "error_reason": error_reason,
1925 "param_reqs": param_reqs,
1926 }
1927 return info_dict
1928
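# Minimal sketch (hypothetical shapes): a RESHAPE from [2, 6] to [3, 5] is a
# size mismatch (12 != 15), while [2, 6] to [3, -1] passes because -1 asks
# for the remaining dimension to be inferred (the shape_inferencing case
# above):
#
#     info = TosaErrorValidator.evTensorSizeInputOutputMismatch(
#         check=True, op={"op": Op.RESHAPE}, input_shape=[2, 6], output_shape=[3, 5]
#     )
#     assert info["error_result"]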
1929 @staticmethod
1930 def evStartSmallerZero(check=False, **kwargs):
1931 error_name = ErrorIf.StartSmallerZero
1932 param_reqs = {"rank": None, "dtype": None, "shape": None}
1933 error_result = False
1934 error_reason = "Starting point smaller than zero"
1935
1936 if check:
1937 input_shape = kwargs["input_shape"]
1938 start = kwargs["start"]
1939 rank = len(input_shape)
1940 if len(start) == rank:
1941 for index in range(rank):
1942 if start[index] < 0:
1943 error_result = True
1944
1945 info_dict = {
1946 "error_name": error_name,
1947 "error_result": error_result,
1948 "error_reason": error_reason,
1949 "param_reqs": param_reqs,
1950 }
1951 return info_dict
1952
1953 @staticmethod
1954 def evSizeSmallerEqualZero(check=False, **kwargs):
1955 error_name = ErrorIf.SizeSmallerEqualZero
1956 param_reqs = {"rank": None, "dtype": None, "shape": None}
1957 error_result = False
1958 error_reason = "Size smaller than or equal to zero"
1959
1960 if check:
1961 input_shape = kwargs["input_shape"]
1962 size = kwargs["size"]
1963 rank = len(input_shape)
1964 if len(size) == rank:
1965 for index in range(rank):
1966 if size[index] <= 0:
1967 error_result = True
1968
1969 info_dict = {
1970 "error_name": error_name,
1971 "error_result": error_result,
1972 "error_reason": error_reason,
1973 "param_reqs": param_reqs,
1974 }
1975 return info_dict
1976
1977 @staticmethod
1978 def evStartSizeOutsideBounds(check=False, **kwargs):
1979 error_name = ErrorIf.StartSizeOutsideBounds
1980 param_reqs = {"rank": None, "dtype": None, "shape": None}
1981 error_result = False
1982 error_reason = "Starting point plus size larger than input dimension"
1983
1984 if check:
1985 input_shape = kwargs["input_shape"]
1986 start = kwargs["start"]
1987 size = kwargs["size"]
1988 rank = len(input_shape)
1989 if len(start) == rank and len(size) == rank:
1990 for index in range(rank):
1991 if start[index] + size[index] > input_shape[index]:
1992 error_result = True
1993
1994 info_dict = {
1995 "error_name": error_name,
1996 "error_result": error_result,
1997 "error_reason": error_reason,
1998 "param_reqs": param_reqs,
1999 }
2000 return info_dict
2001
2002 @staticmethod
2003 def evSizeOutputShapeMismatch(check=False, **kwargs):
2004 error_name = ErrorIf.SizeOutputShapeMismatch
2005 param_reqs = {"rank": None, "dtype": None, "shape": None}
2006 error_result = False
2007 error_reason = "Size does not match output dimension"
2008
2009 if check:
2010 input_shape = kwargs["input_shape"]
2011 output_shape = kwargs["output_shape"]
2012 size = kwargs["size"]
Luke Huttona4e48ca2023-02-22 11:53:48 +00002013
2014 if len(input_shape) == len(output_shape):
2015 rank = len(input_shape)
2016 if len(size) == rank:
2017 for index in range(rank):
2018 if size[index] != output_shape[index]:
2019 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002020
2021 info_dict = {
2022 "error_name": error_name,
2023 "error_result": error_result,
2024 "error_reason": error_reason,
2025 "param_reqs": param_reqs,
2026 }
2027 return info_dict
2028
2029 @staticmethod
2030 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2031 error_name = ErrorIf.InputSizeStartLengthMismatch
2032 param_reqs = {"rank": None, "dtype": None, "shape": None}
2033 error_result = False
2034 error_reason = "Rank of input not equal to length of start or size"
2035
2036 if check:
2037 input_shape = kwargs["input_shape"]
2038 start = kwargs["start"]
2039 size = kwargs["size"]
2040 rank = len(input_shape)
2041 if rank != len(start) or rank != len(size):
2042 error_result = True
2043
2044 info_dict = {
2045 "error_name": error_name,
2046 "error_result": error_result,
2047 "error_reason": error_reason,
2048 "param_reqs": param_reqs,
2049 }
2050 return info_dict
2051
2052 @staticmethod
2053 def evIndexOutsideBounds(check=False, **kwargs):
2054 error_name = ErrorIf.IndexOutsideBounds
2055 param_reqs = {"rank": None, "dtype": None, "shape": None}
2056 error_result = False
2057 error_reason = "Index outside of allowed bounds"
2058
2059 if check:
2060 input_shape = kwargs["input_shape"]
2061 perms = kwargs["perms"]
2062 rank = len(input_shape)
2063
2064 for index in perms:
# Valid permutation values lie in [0, rank)
2065 if index < 0 or index >= rank:
2066 error_result = True
2067
2068 info_dict = {
2069 "error_name": error_name,
2070 "error_result": error_result,
2071 "error_reason": error_reason,
2072 "param_reqs": param_reqs,
2073 }
2074 return info_dict
2075
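# Sketch (hypothetical perms): TRANSPOSE permutation values must each lie
# within the input rank, so for a rank-3 input a value of 4 is out of bounds:
#
#     info = TosaErrorValidator.evIndexOutsideBounds(
#         check=True, input_shape=[2, 3, 4], perms=[0, 4, 1]
#     )
#     assert info["error_result"]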
2076 @staticmethod
2077 def evIndexUsedTwice(check=False, **kwargs):
2078 error_name = ErrorIf.IndexUsedTwice
2079 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2080 error_result = False
2081 error_reason = "Index used multiple times"
2082
2083 if check:
2084 perms = kwargs["perms"]
2085
2086 unique_indices = []
2087 for index in perms:
2088 if index in unique_indices:
2089 error_result = True
2090 else:
2091 unique_indices.append(index)
2092
2093 info_dict = {
2094 "error_name": error_name,
2095 "error_result": error_result,
2096 "error_reason": error_reason,
2097 "param_reqs": param_reqs,
2098 }
2099 return info_dict
2100
2101 @staticmethod
2102 def evMaxSmallerMin(check=False, **kwargs):
2103 error_name = ErrorIf.MaxSmallerMin
2104 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2105 error_result = False
2106 error_reason = "Max value smaller than min value"
2107
2108 if check:
2109 max_val = kwargs["max_val"]
2110 min_val = kwargs["min_val"]
2111 if max_val < min_val:
2112 error_result = True
2113
2114 info_dict = {
2115 "error_name": error_name,
2116 "error_result": error_result,
2117 "error_reason": error_reason,
2118 "param_reqs": param_reqs,
2119 }
2120 return info_dict
2121
2122 @staticmethod
2123 def evConcatInputRankMismatch(check=False, **kwargs):
2124 error_name = ErrorIf.ConcatInputRankMismatch
2125 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2126 error_result = False
2127 error_reason = "Input ranks are not identical"
2128
2129 if check:
2130 inputs = kwargs["inputs"]
2131 input_shape = kwargs["input_shape"]
2132 for input in inputs:
2133 if len(input.shape) != len(input_shape):
2134 error_result = True
2135
2136 info_dict = {
2137 "error_name": error_name,
2138 "error_result": error_result,
2139 "error_reason": error_reason,
2140 "param_reqs": param_reqs,
2141 }
2142 return info_dict
2143
2144 @staticmethod
2145 def evConcatInputDimMismatch(check=False, **kwargs):
2146 error_name = ErrorIf.ConcatInputDimMismatch
2147 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2148 error_result = False
2149 error_reason = "Input dimensions differ on an axis other than the concatenation axis"
2150
2151 if check:
2152 inputs = kwargs["inputs"]
2153 input_shape = kwargs["input_shape"]
2154 axis = kwargs["axis"]
2155
2156 # Ensure rank is valid before checking dims.
2157 valid_rank = True
2158 for input in inputs:
2159 if len(input.shape) != len(input_shape):
2160 valid_rank = False
2161
2162 if valid_rank:
2163 for input in inputs:
2164 for i, dim in enumerate(input.shape):
2165 if dim != input_shape[i] and axis != i:
2166 error_result = True
2167
2168 info_dict = {
2169 "error_name": error_name,
2170 "error_result": error_result,
2171 "error_reason": error_reason,
2172 "param_reqs": param_reqs,
2173 }
2174 return info_dict
2175
2176 @staticmethod
2177 def evConcatShapeSumMismatch(check=False, **kwargs):
2178 error_name = ErrorIf.ConcatShapeSumMismatch
2179 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2180 error_result = False
2181 error_reason = "Sum of dimensions on axis not equal to output dimension"
2182
2183 if check:
2184 inputs = kwargs["inputs"]
2185 input_shape = kwargs["input_shape"]
2186 output_shape = kwargs["output_shape"]
2187 axis = kwargs["axis"]
2188
2189 # Ensure rank is valid before checking dims.
2190 valid_params = True
2191 for input in inputs:
2192 if len(input.shape) != len(input_shape):
2193 valid_params = False
2194 if axis < 0 or axis >= len(input_shape):
2195 valid_params = False
2196
2197 if valid_params:
2198 axis_dim_sum = 0
2199 for input in inputs:
2200 axis_dim_sum += input.shape[axis]
2201
2202 if axis_dim_sum != output_shape[axis]:
2203 error_result = True
2204
2205 info_dict = {
2206 "error_name": error_name,
2207 "error_result": error_result,
2208 "error_reason": error_reason,
2209 "param_reqs": param_reqs,
2210 }
2211 return info_dict
2212
2213 @staticmethod
2214 def evInputListThenGraphMismatch(check=False, **kwargs):
2215 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2216 param_reqs = {"rank": None, "dtype": None, "shape": None}
2217 error_result = False
2218 error_reason = "Input list shape does not match then-graph shape"
2219
2220 if check:
2221 a = kwargs["a"]
2222 b = kwargs["b"]
2223 basicBlocks = kwargs["basicBlocks"]
2224 then_block = basicBlocks[1]
2225 then_inputs = then_block.inputs
2226 then_tens = then_block.tensors
2227 if (a.shape != then_tens[then_inputs[0]].shape) or (
2228 b.shape != then_tens[then_inputs[1]].shape
2229 ):
2230 error_result = True
2231
2232 info_dict = {
2233 "error_name": error_name,
2234 "error_result": error_result,
2235 "error_reason": error_reason,
2236 "param_reqs": param_reqs,
2237 }
2238 return info_dict
2239
2240 @staticmethod
2241 def evInputListElseGraphMismatch(check=False, **kwargs):
2242 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2243 param_reqs = {"rank": None, "dtype": None, "shape": None}
2244 error_result = False
2245 error_reason = "Input list shape does not match else-graph shape"
2246
2247 if check:
2248 a = kwargs["a"]
2249 b = kwargs["b"]
2250 basicBlocks = kwargs["basicBlocks"]
2251 else_block = basicBlocks[2]
2252 else_inputs = else_block.inputs
2253 else_tens = else_block.tensors
2254 if (a.shape != else_tens[else_inputs[0]].shape) or (
2255 b.shape != else_tens[else_inputs[1]].shape
2256 ):
2257 error_result = True
2258
2259 info_dict = {
2260 "error_name": error_name,
2261 "error_result": error_result,
2262 "error_reason": error_reason,
2263 "param_reqs": param_reqs,
2264 }
2265 return info_dict
2266
2267 @staticmethod
2268 def evOutputListThenGraphMismatch(check=False, **kwargs):
2269 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2270 param_reqs = {"rank": None, "dtype": None, "shape": None}
2271 error_result = False
2272 error_reason = "Output list shape does not match then-graph shape"
2273
2274 if check:
2275 basicBlocks = kwargs["basicBlocks"]
2276 cond_block = basicBlocks[0]
2277 cond_outputs = cond_block.outputs
2278 cond_tens = cond_block.tensors
2279 then_block = basicBlocks[1]
2280 then_outputs = then_block.outputs
2281 then_tens = then_block.tensors
2282 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2283 error_result = True
2284
2285 info_dict = {
2286 "error_name": error_name,
2287 "error_result": error_result,
2288 "error_reason": error_reason,
2289 "param_reqs": param_reqs,
2290 }
2291 return info_dict
2292
2293 @staticmethod
2294 def evOutputListElseGraphMismatch(check=False, **kwargs):
2295 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2296 param_reqs = {"rank": None, "dtype": None, "shape": None}
2297 error_result = False
2298 error_reason = "Output list shape does not match else-graph shape"
2299
2300 if check:
2301 basicBlocks = kwargs["basicBlocks"]
2302 cond_block = basicBlocks[0]
2303 cond_outputs = cond_block.outputs
2304 cond_tens = cond_block.tensors
2305 else_block = basicBlocks[2]
2306 else_outputs = else_block.outputs
2307 else_tens = else_block.tensors
2308 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2309 error_result = True
2310
2311 info_dict = {
2312 "error_name": error_name,
2313 "error_result": error_result,
2314 "error_reason": error_reason,
2315 "param_reqs": param_reqs,
2316 }
2317 return info_dict
2318
2319 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002320 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2321 error_name = ErrorIf.CondIfCondNotMatchingBool
2322 param_reqs = {"rank": None, "dtype": None, "shape": None}
2323 error_result = False
2324 error_reason = "Conditional tensor does not match bool type"
2325
2326 if check:
2327 cond = kwargs["cond"]
2328 if cond.dtype != DType.BOOL:
2329 error_result = True
2330
2331 info_dict = {
2332 "error_name": error_name,
2333 "error_result": error_result,
2334 "error_reason": error_reason,
2335 "param_reqs": param_reqs,
2336 }
2337 return info_dict
2338
2339 @staticmethod
2340 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2341 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2342 param_reqs = {"rank": None, "dtype": None, "shape": None}
2343 error_result = False
2344 error_reason = "Conditional tensor does not have a size of one"
2345
2346 if check:
2347 cond = kwargs["cond"]
2348 # Size of 1 is equivalent to rank 0
2349 if len(cond.shape) != 0:
2350 error_result = True
2351
2352 info_dict = {
2353 "error_name": error_name,
2354 "error_result": error_result,
2355 "error_reason": error_reason,
2356 "param_reqs": param_reqs,
2357 }
2358 return info_dict
2359
2360 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002361 def evInputListOutputListMismatch(check=False, **kwargs):
2362 error_name = ErrorIf.InputListOutputListMismatch
2363 param_reqs = {"rank": None, "dtype": None, "shape": None}
2364 error_result = False
2365 error_reason = "Input list does not match output list"
2366
2367 if check:
2368 basicBlocks = kwargs["basicBlocks"]
2369 while_block = basicBlocks[0]
2370 while_inputs = while_block.inputs
2371 while_outputs = while_block.outputs
2372 while_tens = while_block.tensors
2373 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2374 error_result = True
2375
2376 info_dict = {
2377 "error_name": error_name,
2378 "error_result": error_result,
2379 "error_reason": error_reason,
2380 "param_reqs": param_reqs,
2381 }
2382 return info_dict
2383
2384 @staticmethod
2385 def evInputListCondGraphMismatch(check=False, **kwargs):
2386 error_name = ErrorIf.InputListCondGraphMismatch
2387 param_reqs = {"rank": None, "dtype": None, "shape": None}
2388 error_result = False
2389 error_reason = "Input list does not match cond graph"
2390
2391 if check:
2392 basicBlocks = kwargs["basicBlocks"]
2393 while_block = basicBlocks[0]
2394 while_inputs = while_block.inputs
2395 while_tens = while_block.tensors
2396 cond_block = basicBlocks[1]
2397 cond_inputs = cond_block.inputs
2398 cond_tens = cond_block.tensors
2399 if (
2400 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2401 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2402 error_result = True
2403
2404 info_dict = {
2405 "error_name": error_name,
2406 "error_result": error_result,
2407 "error_reason": error_reason,
2408 "param_reqs": param_reqs,
2409 }
2410 return info_dict
2411
2412 @staticmethod
2413 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2414 error_name = ErrorIf.InputListBodyGraphInputMismatch
2415 param_reqs = {"rank": None, "dtype": None, "shape": None}
2416 error_result = False
2417 error_reason = "Input list does not match body graph input"
2418
2419 if check:
2420 basicBlocks = kwargs["basicBlocks"]
2421 while_block = basicBlocks[0]
2422 while_inputs = while_block.inputs
2423 while_tens = while_block.tensors
2424 body_block = basicBlocks[2]
2425 body_inputs = body_block.inputs
2426 body_tens = body_block.tensors
2427 if (
2428 while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
2429 ) or (
2430 while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
2431 ):
2432 error_result = True
2433
2434 info_dict = {
2435 "error_name": error_name,
2436 "error_result": error_result,
2437 "error_reason": error_reason,
2438 "param_reqs": param_reqs,
2439 }
2440 return info_dict
2441
2442 @staticmethod
2443 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2444 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2445 param_reqs = {"rank": None, "dtype": None, "shape": None}
2446 error_result = False
2447 error_reason = "Input list does not match body graph output"
2448
2449 if check:
2450 basicBlocks = kwargs["basicBlocks"]
2451 while_block = basicBlocks[0]
2452 while_inputs = while_block.inputs
2453 while_tens = while_block.tensors
2454 body_block = basicBlocks[2]
2455 body_outputs = body_block.outputs
2456 body_tens = body_block.tensors
2457 if (
2458 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2459 ) or (
2460 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2461 ):
2462 error_result = True

2463 info_dict = {
2464 "error_name": error_name,
2465 "error_result": error_result,
2466 "error_reason": error_reason,
2467 "param_reqs": param_reqs,
2468 }
2469 return info_dict
2470
2471 @staticmethod
2472 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2473 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2474 param_reqs = {"rank": None, "dtype": None, "shape": None}
2475 error_result = False
2476 error_reason = "Cond graph output is not a boolean tensor"
2477
2478 if check:
2479 basicBlocks = kwargs["basicBlocks"]
2480 cond_block = basicBlocks[1]
2481 cond_outputs = cond_block.outputs
2482 cond_tens = cond_block.tensors
2483 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2484 error_result = True
2485
2486 info_dict = {
2487 "error_name": error_name,
2488 "error_result": error_result,
2489 "error_reason": error_reason,
2490 "param_reqs": param_reqs,
2491 }
2492 return info_dict
2493
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002494 @staticmethod
2495 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2496 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2497 param_reqs = {"rank": None, "dtype": None, "shape": None}
2498 error_result = False
2499 error_reason = "Cond graph output does not have a shape of size one"
2500
2501 if check:
2502 basicBlocks = kwargs["basicBlocks"]
2503 cond_block = basicBlocks[1]
2504 cond_outputs = cond_block.outputs
2505 cond_tens = cond_block.tensors
2506 # Size of 1 is equivalent to rank 0
2507 if len(cond_tens[cond_outputs[0]].shape) != 0:
2508 error_result = True
2509
2510 info_dict = {
2511 "error_name": error_name,
2512 "error_result": error_result,
2513 "error_reason": error_reason,
2514 "param_reqs": param_reqs,
2515 }
2516 return info_dict
2517
Luke Hutton261b7b62023-01-10 14:50:31 +00002518 @staticmethod
2519 def evKernelNotPowerOfTwo(check=False, **kwargs):
2520 error_name = ErrorIf.KernelNotPowerOfTwo
2521 param_reqs = {"rank": None, "dtype": None, "shape": None}
2522 error_result = False
2523 error_reason = "Kernel height and/or width not a power of two"
2524
2525 def is_power_of_two(x):
# math.log2 is exact for integer powers of two, unlike math.log(x, 2)
2526 return math.log2(x).is_integer()
2527
2528 if check:
2529 shape = kwargs["input_shape"]
2530 if len(shape) == 3:
2531 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2532 error_result = not valid_kernel
2533
2534 info_dict = {
2535 "error_name": error_name,
2536 "error_result": error_result,
2537 "error_reason": error_reason,
2538 "param_reqs": param_reqs,
2539 }
2540 return info_dict
2541
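# Sketch (hypothetical shape): FFT kernel dimensions must be powers of two,
# so an input of [1, 8, 6] is flagged (8 is a power of two, 6 is not):
#
#     info = TosaErrorValidator.evKernelNotPowerOfTwo(check=True, input_shape=[1, 8, 6])
#     assert info["error_result"]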
Luke Hutton57287132023-02-06 14:54:18 +00002542 @staticmethod
2543 def evFFTInputShapeMismatch(check=False, **kwargs):
2544 error_name = ErrorIf.FFTInputShapeMismatch
2545 param_reqs = {"rank": None, "dtype": None, "shape": None}
2546 error_result = False
2547 error_reason = "Mismatch between real and imaginary input shapes"
2548
2549 if check:
2550 input1 = kwargs["input1"]
2551 input2 = kwargs["input2"]
2552
2553 if input1.shape != input2.shape:
2554 error_result = True
2555
2556 info_dict = {
2557 "error_name": error_name,
2558 "error_result": error_result,
2559 "error_reason": error_reason,
2560 "param_reqs": param_reqs,
2561 }
2562 return info_dict
2563
2564 @staticmethod
2565 def evFFTOutputShapeMismatch(check=False, **kwargs):
2566 error_name = ErrorIf.FFTOutputShapeMismatch
2567 param_reqs = {"rank": None, "dtype": None, "shape": None}
2568 error_result = False
2569 error_reason = (
2570 "Mismatch between provided and expected output kernel (H, W) shape"
2571 )
2572
2573 if check:
2574 op = kwargs["op"]
2575 input_shape = kwargs["input_shape"]
2576
2577 if len(input_shape) == 3:
2578 output_shapes = kwargs["output_shape"]
2579
2580 # Ignoring batch size (N) from input shape
2581 expected_shape = input_shape[1:]
2582 if op["op"] == Op.RFFT2D:
2583 expected_shape[1] = expected_shape[1] // 2 + 1
2584
2585 # Ignoring batch size (N) from output shapes
2586 output_shape_0 = output_shapes[0][1:]
2587 output_shape_1 = output_shapes[1][1:]
2588 # Ensure the kernel sizes (H, W) of both outputs match the expected shape
2589 if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
2590 error_result = True
2591
2592 info_dict = {
2593 "error_name": error_name,
2594 "error_result": error_result,
2595 "error_reason": error_reason,
2596 "param_reqs": param_reqs,
2597 }
2598 return info_dict
2599
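# Worked example (hypothetical shapes): RFFT2D keeps H but shrinks W to
# W // 2 + 1, so an input of [1, 8, 8] expects both outputs to be [1, 8, 5];
# FFT2D expects the (H, W) kernel shape to pass through unchanged.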
Jerry Ge264f7fa2023-04-21 22:49:57 +00002600 @staticmethod
Jerry Ge135c9552023-05-23 20:59:32 +00002601 def calculateBroadcastShape(input_shape_a, input_shape_b):
2602 if input_shape_a is not None and input_shape_b is not None:
2603 calculated_shape = input_shape_a.copy()
2604 for idx in range(len(calculated_shape)):
2605 if calculated_shape[idx] == 1:
2606 calculated_shape[idx] = input_shape_b[idx]
2607 elif (
2608 input_shape_b[idx] != 1
2609 and input_shape_b[idx] != calculated_shape[idx]
2610 ):
2611 return None
2612 return calculated_shape
2613 else:
2614 return None
2615
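# Sketch (hypothetical shapes) of the pairwise broadcast helper: a dimension
# of 1 stretches to the other operand's size, anything else must match
# exactly, and None signals an impossible broadcast:
#
#     assert TosaErrorValidator.calculateBroadcastShape([1, 3], [2, 3]) == [2, 3]
#     assert TosaErrorValidator.calculateBroadcastShape([2, 3], [4, 3]) is None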
2616 @staticmethod
2617 def evBroadcastShapesMismatch(check=False, **kwargs):
2618 error_name = ErrorIf.BroadcastShapesMismatch
2619 param_reqs = {"rank": None, "dtype": None, "shape": None}
2620 error_result = False
2621 error_reason = "Broadcast shape calculation failed"
2622
2623 if check:
2624 input_shape_a = kwargs["input1"].shape
2625 input_shape_b = kwargs["input2"].shape
2626 input_shape_c = (
2627 kwargs["input3"].shape if "input3" in kwargs else input_shape_b
2628 )
2629
2630 if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
2631 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
2632 input_shape_c,
2633 TosaErrorValidator.calculateBroadcastShape(
2634 input_shape_a, input_shape_b
2635 ),
2636 )
2637 error_result = calculated_shape is None
2638
2639 info_dict = {
2640 "error_name": error_name,
2641 "error_result": error_result,
2642 "error_reason": error_reason,
2643 "param_reqs": param_reqs,
2644 }
2645 return info_dict
2646
@staticmethod
2647 def evWrongAccumulatorType(check=False, **kwargs):
2648 error_name = ErrorIf.WrongAccumulatorType
2649 param_reqs = {"rank": None, "dtype": None, "shape": None}
2650 error_result = False
2651 error_reason = "An unsupported accumulator data type was requested"
2652
2653 if check:
2654 op = kwargs["op"]
2655 input_dtype = kwargs["input_dtype"]
2656 accum_dtype = kwargs["accum_dtype"]
2657 if op["op"] == Op.AVG_POOL2D:
2658 if (
2659 input_dtype
2660 in (
2661 DType.INT8,
2662 DType.INT16,
2663 )
2664 and accum_dtype != DType.INT32
2665 ):
2666 error_result = True
2667 elif (
2668 input_dtype
2669 in (
2670 DType.FP32,
2671 DType.BF16,
2672 )
2673 and accum_dtype != DType.FP32
2674 ):
2675 error_result = True
2676 elif input_dtype == DType.FP16 and accum_dtype not in (
2677 DType.FP16,
2678 DType.FP32,
2679 ):
2680 error_result = True
Won Jeon2c34b462024-02-06 18:37:00 +00002681 elif (
2682 input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
2683 and accum_dtype != DType.FP16
2684 ):
2685 error_result = True
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002686
2687 info_dict = {
2688 "error_name": error_name,
2689 "error_result": error_result,
2690 "error_reason": error_reason,
2691 "param_reqs": param_reqs,
2692 }
2693 return info_dict
2694
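# Sketch (hypothetical dtypes) of the AVG_POOL2D accumulator rules above:
# INT8/INT16 require an INT32 accumulator, FP32/BF16 require FP32, FP16 may
# accumulate in FP16 or FP32, and the FP8 formats require FP16:
#
#     info = TosaErrorValidator.evWrongAccumulatorType(
#         check=True,
#         op={"op": Op.AVG_POOL2D},
#         input_dtype=DType.INT8,
#         accum_dtype=DType.INT48,
#     )
#     assert info["error_result"]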
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002695
2696class TosaInvalidValidator:
2697 @staticmethod
2698 def ivWrongDataTypeOrModeResize(**kwargs):
2699 input_dtype = kwargs["input_dtype"]
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002700 args_dict = kwargs["args"]
2701 mode = args_dict["mode"]
2702 output_dtype = args_dict["output_dtype"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002703
2704 if mode == ResizeMode.BILINEAR:
2705 # Invalid output data type / Invalid input datatype
2706 return (
2707 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002708 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002709 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002710 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002711 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002712 )
2713 elif mode == ResizeMode.NEAREST:
2714 # Invalid output data type / Invalid input datatype
2715 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002716 input_dtype
2717 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002718 )
2719 else:
2720 # Invalid resize mode
2721 return True
2722
2723 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002724 def ivHeightWidthInvalid(**kwargs):
2725 opName = kwargs["opName"]
2726
2727 inputShapes = kwargs["shapeList"]
2728 input_shape = inputShapes[0]
2729
2730 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002731
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002732 if isinstance(args, dict):
2733 args_dict = args
2734 else:
2735 # Create args_dict from list elements
2736 # TODO - Remove this once all NHWC operators agFunctions have been
2737 # converted to args_dict output
2738
2739 # Skip accum_dtype arg (MaxPool2D is the only op without one)
2740 stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
2741 args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
2742 # The same positional argument means different things per op, so alias it under each name
2743 args_dict["kernel"] = args[pad_idx + 1]
2744 args_dict["out_shape"] = args[pad_idx + 1]
2745 args_dict["dilation"] = args[pad_idx + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002746
2747 # Common info for all ops
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002748 strides = args_dict["stride"]
2749 padding = args_dict["pad"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002750
2751 if opName.endswith("pool2d"):
2752 # avg_pool2d, max_pool2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002753 kernel_shape = args_dict["kernel"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002754 h = (
2755 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
2756 ) // strides[0]
2757 w = (
2758 input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
2759 ) // strides[1]
2760 # return True if any dimension is < 1
2761 return h < 1 or w < 1
2762
2763 if opName.startswith("transpose_conv2d"):
2764 # transpose_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002765 output_shape = args_dict["out_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 filter_shape = inputShapes[1]
2767 kernel_shape = filter_shape[1:-1]
2768
TatWai Chong24594f52022-06-08 00:48:04 -07002769 def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002770 """Calculate the transpose_conv2d output size for a dimension."""
2771 return (in_size - 1) * stride + kernel_size + in_pad + out_pad
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002772
Jeremy Johnson0c716862023-04-13 17:18:19 +01002773 h = get_out_size(
2774 input_shape[1],
2775 strides[0],
2776 kernel_shape[0],
2777 padding[0],
2778 padding[1],
2779 )
2780 w = get_out_size(
2781 input_shape[2],
2782 strides[1],
2783 kernel_shape[1],
2784 padding[2],
2785 padding[3],
2786 )
2787 if output_shape[1] == h and output_shape[2] == w:
2788 return False
2789 # output shape does not match the expected shape
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002790 return True
2791
2792 if "conv2d" in opName or "conv3d" in opName:
2793 # conv2d, conv3d, depthwise_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002794 dilations = args_dict["dilation"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002795 filter_shape = inputShapes[1]
2796 kernel_shape = (
2797 filter_shape[0:2]
2798 if opName.startswith("depthwise_conv2d")
2799 else filter_shape[1:-1]
2800 )
2801
2802 for i in range(len(kernel_shape)):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002803 pad_offset = i * 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002804 dim = (
2805 input_shape[i + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002806 - 1
2807 + padding[pad_offset]
2808 + padding[pad_offset + 1]
2809 - (kernel_shape[i] - 1) * dilations[i]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002810 ) // strides[i] + 1
2811 # return True if any dimension is < 1
2812 if dim < 1:
2813 return True
2814 return False
2815
2816 assert False, f"Unrecognized Op: {opName}"
2817
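# Sketch (hypothetical args) of the dictionary form this helper expects: a
# max_pool2d whose stride overruns the padded input collapses a dimension
# below one, making the test invalid rather than an ERROR_IF case:
#
#     TosaInvalidValidator.ivHeightWidthInvalid(
#         opName="max_pool2d",
#         shapeList=[[1, 2, 2, 8]],
#         args={"stride": [4, 4], "pad": [0, 0, 0, 0], "kernel": [4, 4]},
#     )  # True: h == (2 + 0 + 0 + 4 - 4) // 4 == 0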
2818 @staticmethod
2819 def ivNonPositiveOutputShape(**kwargs):
2820 args = kwargs["args"]
Jeremy Johnson95a67102024-01-10 14:16:39 +00002821 output_shape = args["out_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002822 if output_shape[1] <= 0 or output_shape[2] <= 0:
2823 # Negative output shape
2824 return True
2825 return False