blob: 9fd13d2d26ca95d776d90768ba6ef75e44a9289d [file] [log] [blame]
Won Jeon74342e52024-01-09 00:34:40 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnsonaf090182024-02-13 18:25:39 +00003import logging
Luke Hutton261b7b62023-01-10 14:50:31 +00004import math
5
Jeremy Johnsondd975b82024-02-28 17:29:13 +00006import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from tosa.DType import DType
9from tosa.Op import Op
10from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011
# Configure default logging handlers once at import time; all diagnostics from
# this module are routed through the shared test-builder logger below.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
14
Matthew Haddone86fd342021-09-07 16:12:21 +010015
class ErrorIf:
    """Names of all TOSA ERROR_IF conditions that can be generated/validated.

    Each attribute is a string constant equal to its own name; the strings are
    used as test-case identifiers and matched against validator results.
    """

    # RESIZE argument errors
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    # Generic operator interface errors
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    # Zero-point errors
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    # Axis / reduction errors
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    # Pooling / convolution argument errors
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    # RESCALE errors
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    # RESHAPE / SLICE / GATHER errors
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    # CLAMP / CONCAT errors
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    # Control-flow (COND_IF / WHILE_LOOP) errors
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    # FFT errors
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    # RESHAPE inference / broadcasting errors
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"
    WrongAccumulatorType = "WrongAccumulatorType"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089
90
91class TosaErrorIfArgGen:
92 @staticmethod
93 def eiResizeErrorIf(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010094 rng,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010095 error_name,
96 mode,
97 dtype,
98 shapeList,
99 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100100 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100101 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100102 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100104 if error_name == ErrorIf.ScaleSmallerEqualZero:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100105 index = rng.randInt(low=0, high=4)
106 scale[index] = rng.choice([-2, -1, 0])
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100107 elif error_name == ErrorIf.ScaleNLargerMax:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100108 index = rng.choice([0, 2])
109 scale[index] = (1 << 11) + rng.choice([1, 2, 3])
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100110 elif error_name == ErrorIf.ScaleDLargerMax:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100111 index = rng.choice([1, 3])
112 scale[index] = 16 * scale[index - 1] + rng.choice([0, 1, 2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100113
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100114 if error_name == ErrorIf.OffsetLargerEqualMax:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100115 index = rng.choice([0, 1])
116 offset[index] = 16 * scale[index * 2] + rng.choice([0, 1, 2])
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100117 elif error_name == ErrorIf.OffsetSmallerMin:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100118 index = rng.choice([0, 1])
119 offset[index] = -scale[index * 2] - rng.choice([1, 2, 3])
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100120
121 if error_name == ErrorIf.BorderLargerEqualMax:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100122 index = rng.choice([0, 1])
123 border[index] = scale[index * 2] + rng.choice([0, 1, 2])
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100124 elif error_name == ErrorIf.BorderSmallerMin:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100125 index = rng.choice([0, 1])
126 border[index] = -16 * scale[index * 2] - rng.choice([1, 2, 3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100127
128 if error_name == ErrorIf.WrongOutputType:
129 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
130 incorrect_types = (
131 DType.INT4,
132 DType.INT16,
133 DType.INT32,
134 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100135 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100136 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100137 )
138 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
139 incorrect_types = (
140 DType.INT4,
141 DType.INT8,
142 DType.INT32,
143 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100144 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100145 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100146 )
147 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
148 incorrect_types = (
149 DType.INT4,
150 DType.INT8,
151 DType.INT16,
152 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100153 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100154 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100155 )
156 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
157 incorrect_types = (
158 DType.INT4,
159 DType.INT8,
160 DType.INT16,
161 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100162 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +0100163 DType.FP16,
164 )
165 elif dtype == DType.FP16:
166 incorrect_types = (
167 DType.INT4,
168 DType.INT8,
169 DType.INT16,
170 DType.INT32,
171 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100172 DType.FP32,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100173 )
James Ward24dbc422022-10-19 12:20:31 +0100174 elif dtype == DType.BF16:
175 incorrect_types = (
176 DType.INT4,
177 DType.INT8,
178 DType.INT16,
179 DType.INT32,
180 DType.INT48,
181 DType.FP32,
182 )
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100183 elif dtype == DType.FP32:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100184 incorrect_types = (
185 DType.INT4,
186 DType.INT8,
187 DType.INT16,
188 DType.INT32,
189 DType.INT48,
James Ward8b390432022-08-12 20:48:56 +0100190 DType.FP16,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100191 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100192 outputDType = rng.choice(a=incorrect_types)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100193
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100194 return scale, offset, border, outputDType
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100195
196 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100197 def eiPoolingErrorIf(rng, error_name, stride, pad, kernel):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100198 if (
199 error_name == ErrorIf.StrideSmallerOne
200 # padding must not exceed the kernel size
201 and pad[0] < kernel[0]
202 and pad[1] < kernel[0]
203 and pad[2] < kernel[1]
204 and pad[3] < kernel[1]
205 ):
206 wrongStride = (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100207 rng.choice([0, -1, -2, -3]),
208 rng.choice([0, -1, -2, -3]),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100209 )
210 return wrongStride, pad, kernel
211 elif error_name == ErrorIf.PadSmallerZero:
212 wrongPad = (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100213 rng.choice([-1, -2, -3]),
214 rng.choice([-1, -2, -3]),
215 rng.choice([-1, -2, -3]),
216 rng.choice([-1, -2, -3]),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100217 )
218 return stride, wrongPad, kernel
219 elif error_name == ErrorIf.KernelSmallerOne:
220 wrongKernel = (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100221 rng.choice([0, -1, -2, -3]),
222 rng.choice([0, -1, -2, -3]),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100223 )
224 return stride, pad, wrongKernel
225 elif error_name == ErrorIf.PadLargerEqualKernel:
226 wrongPad = (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100227 rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
228 rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
229 rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
230 rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100231 )
232 return stride, wrongPad, kernel
233 else:
234 return None, None, None
235
236 @staticmethod
237 def eiRescaleWrongOutputType(input_dtype, output_dtype):
238 if input_dtype == DType.INT8:
239 if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
240 return True
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100241 elif input_dtype == DType.INT16:
242 if output_dtype not in [
243 DType.UINT8,
244 DType.INT8,
245 DType.UINT16,
246 DType.INT16,
247 DType.INT32,
248 ]:
249 return True
250 elif input_dtype == DType.INT32:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100251 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
252 return True
253 elif input_dtype == DType.INT48:
254 if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
255 return True
256 elif input_dtype == DType.UINT8:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100257 if output_dtype not in [DType.INT8, DType.INT16]:
258 return True
259 elif input_dtype == DType.UINT16:
260 if output_dtype != DType.INT16:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100261 return True
262 return False
263
264 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100265 def eiInvalidateInputOutputList(rng, error_name, input_list, output_list):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100266 # Mess up input/output tensors for ERROR_IF checks
267 if error_name == "WrongInputList":
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100268 add_input = rng.choice([True, False])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100269 if add_input:
270 input_list.append("eiDummyInput")
271 else:
272 input_list = input_list[:-1]
273 elif error_name == "WrongOutputList":
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100274 add_output = rng.choice([True, False])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100275 if add_output:
276 output_list.append("eiDummyOutput")
277 else:
278 output_list = []
279 return input_list, output_list
280
281 @staticmethod
282 def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
283 """Restrict the dimensions and overall size of a shape to
284 max_dim and max_items.
285 """
286 new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
Jeremy Johnsondd975b82024-02-28 17:29:13 +0000287 while gtu.product(new_shape) > max_items:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100288 new_shape = [max(d - 1, 1) for d in new_shape]
289 return new_shape
290
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100291 def eiSliceErrorIf(rng, error_name, input_shape, start, size):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100292 if error_name == ErrorIf.StartSmallerZero:
293 newStart = []
294 for i in range(len(input_shape)):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100295 newStart.append(rng.choice([-3, -2, -1]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100296 return newStart, size
297 elif error_name == ErrorIf.SizeSmallerEqualZero:
298 newSize = []
299 for i in range(len(input_shape)):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100300 newSize.append(rng.choice([-3, -2, -1, 0]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100301 return start, newSize
302 elif error_name == ErrorIf.StartSizeOutsideBounds:
303 newStart, newSize = [], []
304 for i in range(len(input_shape)):
305 newStart.append(input_shape[i] - 1)
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100306 newSize.append(rng.choice([2, 3, 4]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100307 return newStart, newSize
308 elif error_name == ErrorIf.InputSizeStartLengthMismatch:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100309 remove = rng.choice([True, False])
TatWai Chongf15bad82024-01-31 21:33:27 -0800310
311 # Get an empty tensor when diminishing dimension on 1-d tensor.
312 if len(start) == 1 or len(size) == 1:
313 remove = False
314
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100315 if remove:
316 newStart = start[1:]
317 newSize = size[1:]
318 else:
319 newStart = start
320 newStart.append(1)
321 newSize = size
322 newSize.append(1)
323 return newStart, newSize
324 else:
325 return start, size
326
327 @staticmethod
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100328 def eiCastErrorIf(input_dtype):
Won Jeon2c34b462024-02-06 18:37:00 +0000329 if input_dtype in [DType.BOOL]:
330 outputDType = [
331 DType.BOOL,
332 DType.INT48,
333 DType.FP32,
334 DType.FP16,
335 DType.BF16,
336 DType.FP8E4M3,
337 DType.FP8E5M2,
338 ]
339 elif input_dtype in [DType.FP32]:
James Ward736fd1a2023-01-23 17:13:37 +0000340 outputDType = [DType.BOOL, DType.INT48, DType.FP32]
341 elif input_dtype in [DType.FP16, DType.BF16]:
342 outputDType = [DType.BOOL, DType.INT48]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100343 elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
344 outputDType = [DType.INT48]
Won Jeon2c34b462024-02-06 18:37:00 +0000345 elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
346 outputDType = [
347 DType.BOOL,
348 DType.INT8,
349 DType.INT16,
350 DType.INT32,
351 DType.INT48,
352 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100353 else:
James Ward736fd1a2023-01-23 17:13:37 +0000354 assert False, f"input_dtype ({input_dtype}) not supported"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100355 return outputDType
356
357
358class TosaErrorValidator:
359 @staticmethod
360 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
361 """Check ERROR_IF statements are caught and set the expected result.
362
363 Args:
364 serializer: the serializer to set the expected result in
365 validator_fcns: a sequence of validator functions to verify the result
366 error_name: the name of the ERROR_IF condition to check for
367 kwargs: keyword arguments for the validator functions
368 Returns:
369 True if the result matches the expected result; otherwise False
370 """
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000371 if validator_fcns is None:
372 # Nothing to do
373 return True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100374 overall_result = True
375 for val_fcn in validator_fcns:
376 val_result = val_fcn(True, **kwargs)
377 validator_name = val_result["error_name"]
378 error_result = val_result["error_result"]
379 error_reason = val_result["error_reason"]
380
381 # expect an error IFF the error_name and validator_name match
382 expected_result = error_result == (error_name == validator_name)
383 overall_result &= expected_result
384
385 if expected_result and error_result:
386 serializer.setExpectedReturnCode(2, True, desc=error_reason)
387 elif error_result: # and not expected_result
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000388 logger.error(
Jeremy Johnsondd975b82024-02-28 17:29:13 +0000389 f"Unexpected ERROR_IF: Op: {gtu.valueToName(Op, kwargs['op']['op'])}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100390 f" Expected: {error_name}, Got: {validator_name}"
391 )
392 elif not expected_result: # and not error_result
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000393 logger.error(
Jeremy Johnsondd975b82024-02-28 17:29:13 +0000394 f"Missed ERROR_IF: Op: {gtu.valueToName(Op, kwargs['op']['op'])}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100395 f" Expected: {error_name}"
396 )
397
398 if not expected_result:
399 for k, v in sorted(kwargs.items()):
400 if k != "op":
401 if k.endswith("dtype"):
Jeremy Johnsondd975b82024-02-28 17:29:13 +0000402 v = gtu.valueToName(DType, v)
Jeremy Johnsonaf090182024-02-13 18:25:39 +0000403 logger.error(f" {k} = {v}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100404
405 return overall_result
406
407 @staticmethod
408 def evWrongInputType(check=False, **kwargs):
409 error_result = False
410
411 # Find the unsupported input data types
412 op = kwargs["op"]
413 input_dtypes = op["types"]
414 allowed_input_dtypes = {
415 t[0] if isinstance(t, list) else t for t in input_dtypes
416 }
Jeremy Johnsondd975b82024-02-28 17:29:13 +0000417 wrong_input_dtypes = list(gtu.usableDTypes(excludes=allowed_input_dtypes))
418 assert len(wrong_input_dtypes) > 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100419
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100420 # Turn the wrong dtypes into required list of types
421 if op["op"] in [
422 Op.FULLY_CONNECTED,
423 Op.CONV2D,
424 Op.CONV3D,
425 Op.DEPTHWISE_CONV2D,
426 Op.TRANSPOSE_CONV2D,
427 ]:
428 wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
429
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100430 if op["op"] == Op.CLAMP:
431 wrong_input_dtypes.remove(DType.INT48)
432
433 if check:
434 input_dtype = kwargs["input_dtype"]
435 if input_dtype not in allowed_input_dtypes:
436 error_result = True
437
438 info_dict = {
439 "error_name": ErrorIf.WrongInputType,
440 "error_result": error_result,
441 "error_reason": "Input data type not supported for this operator",
442 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
443 }
444 return info_dict
445
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator for the WrongOutputType ERROR_IF condition.

        When check is True, flags an error if kwargs["output_dtype"] is not a
        valid output type for kwargs["op"] given kwargs["input_dtype"] (and,
        for RESIZE, kwargs["mode"]).  Returns the standard validator info dict.
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # NEAREST keeps the input type; BILINEAR widens the
                # accumulator (INT8->INT32, INT16->INT48); floats keep type
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE validity is shared with the argument generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
                    or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always produces INT32 indices
                if (
                    input_dtype
                    in [
                        DType.INT8,
                        DType.INT16,
                        DType.FP16,
                        DType.BF16,
                        DType.FP32,
                        DType.FP8E4M3,
                        DType.FP8E5M2,
                    ]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL accumulates into INT32; float MUL keeps its type
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison operators always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input dtype has an explicit allow-list of cast targets
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
                        and output_dtype
                        not in [
                            DType.FP16,
                            DType.BF16,
                            DType.FP32,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                # FFT ops return a pair of tensors, all of the input's type
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype != DType.FP16
                    or input_dtype == DType.BF16
                    and output_dtype != DType.BF16
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP8E4M3
                    and output_dtype != DType.FP16
                    or input_dtype == DType.FP8E5M2
                    and output_dtype != DType.FP16
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
                # Shape arithmetic ops always produce SHAPE values
                if output_dtype != DType.SHAPE:
                    error_result = True

            else:
                # Default rule: output type must match input type
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
680
681 @staticmethod
682 def evWrongRank(check=False, **kwargs):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100683 # Make a list of incorrect ranks
684 assert "op" in kwargs
685 op = kwargs["op"]
686 rmin, rmax = op["rank"]
687 rank_range = range(rmin, rmax + 1)
Jeremy Johnsonac8c0c82024-03-21 10:32:26 +0000688 # From 1 to rmax+1 inclusively
689 all_ranks = tuple(range(1, rmax + 2))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100690 incorrect_ranks = list(set(all_ranks) - set(rank_range))
691 # Remove small incorrect ranks to avoid index errors
692 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
693 # Set minimum incorrect rank to 3 to avoid index error
694 if op["op"] in [Op.RESIZE]:
695 incorrect_ranks = [3, 5]
696 elif op["op"] in [Op.TRANSPOSE]:
697 incorrect_ranks = [7, 8]
698 elif op["op"] in [Op.CONV3D]:
699 incorrect_ranks = [6, 7]
700
701 error_name = ErrorIf.WrongRank
702 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
703 error_result = False
704 error_reason = "Rank not supported for this operator"
705
706 if check:
707 input_shape = kwargs["input_shape"]
708
709 if (
710 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
711 and len(input_shape) != 4
712 ):
713 error_result = True
714 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
715 error_result = True
716 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
717 error_result = True
718 else:
719 if len(input_shape) not in rank_range:
720 error_result = True
721
722 info_dict = {
723 "error_name": error_name,
724 "error_result": error_result,
725 "error_reason": error_reason,
726 "param_reqs": param_reqs,
727 }
728 return info_dict
729
730 @staticmethod
731 def evWrongInputList(check=False, **kwargs):
732 error_name = ErrorIf.WrongInputList
733 param_reqs = {"rank": None, "dtype": None, "shape": None}
734 error_result = False
735 error_reason = "Op input list does not match expected input"
736
737 if check:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100738 input_list = kwargs["input_list"]
739 num_operands = kwargs["num_operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100740 if len(input_list) != num_operands:
741 error_result = True
742
743 info_dict = {
744 "error_name": error_name,
745 "error_result": error_result,
746 "error_reason": error_reason,
747 "param_reqs": param_reqs,
748 }
749 return info_dict
750
751 @staticmethod
752 def evWrongOutputList(check=False, **kwargs):
753 error_name = ErrorIf.WrongOutputList
754 param_reqs = {"rank": None, "dtype": None, "shape": None}
755 error_result = False
756 error_reason = "Op output list does not match expected output"
757
758 if check:
Luke Hutton261b7b62023-01-10 14:50:31 +0000759 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100760 output_list = kwargs["output_list"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000761 expected_length = 1
Luke Hutton57287132023-02-06 14:54:18 +0000762 if op["op"] in [Op.FFT2D, Op.RFFT2D]:
Luke Hutton261b7b62023-01-10 14:50:31 +0000763 expected_length = 2
764
765 if len(output_list) != expected_length:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100766 error_result = True
767
768 info_dict = {
769 "error_name": error_name,
770 "error_result": error_result,
771 "error_reason": error_reason,
772 "param_reqs": param_reqs,
773 }
774 return info_dict
775
    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        """Check for RESIZE spatial dimensions that reach the maximum size.

        Flags the error when any H/W dimension (indices 1 and 2) of the
        input or output shape is >= gtu.MAX_RESIZE_DIMENSION.
        """
        error_name = ErrorIf.MaxDimExceeded
        # Pre-baked over-sized shapes used when generating negative tests
        # for this error (rank-4, INT8 only).
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {gtu.MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # Indices 1 and 2 are the spatial (height, width) dimensions.
            if (
                (input_shape[1] >= gtu.MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= gtu.MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= gtu.MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= gtu.MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
805
806 @staticmethod
807 def evBatchMismatch(check=False, **kwargs):
808 error_name = ErrorIf.BatchMismatch
Luke Hutton261b7b62023-01-10 14:50:31 +0000809 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100810 error_result = False
811 error_reason = "Input batch size not equal to output batch size"
812
813 assert "op" in kwargs
814 op = kwargs["op"]
815 rmin, rmax = op["rank"]
816 rank_range = range(rmin, rmax + 1)
817
818 if check:
819 input_shape = kwargs["input_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100820
Luke Hutton261b7b62023-01-10 14:50:31 +0000821 for output in kwargs["result_tensors"]:
822 output_shape = (
823 output.shape
824 ) # Note batch is expected to be the first dim
825 if (len(input_shape) in rank_range) and (
826 input_shape[0] != output_shape[0]
827 ):
828 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100829
830 info_dict = {
831 "error_name": error_name,
832 "error_result": error_result,
833 "error_reason": error_reason,
834 "param_reqs": param_reqs,
835 }
836 return info_dict
837
838 @staticmethod
839 def evChannelMismatch(check=False, **kwargs):
840 error_name = ErrorIf.ChannelMismatch
841 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
842 error_result = False
843 error_reason = "Input channel size not equal to output channel size"
844
845 assert "op" in kwargs
846 op = kwargs["op"]
847 rmin, rmax = op["rank"]
848 rank_range = range(rmin, rmax + 1)
849
850 if check:
851 input_shape = kwargs["input_shape"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000852 for output in kwargs["result_tensors"]:
853 output_shape = output.shape # Note this is just (N, OH, OW, C)
854 if (len(input_shape) in rank_range) and (
855 input_shape[3] != output_shape[3]
856 ):
857 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100858
859 info_dict = {
860 "error_name": error_name,
861 "error_result": error_result,
862 "error_reason": error_reason,
863 "param_reqs": param_reqs,
864 }
865 return info_dict
866
867 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100868 def evScaleSmallerEqualZero(check=False, **kwargs):
869 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100870 param_reqs = {"rank": None, "dtype": None, "shape": None}
871 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100872 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100873
874 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100875 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100876
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100877 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100878 error_result = True
879
880 info_dict = {
881 "error_name": error_name,
882 "error_result": error_result,
883 "error_reason": error_reason,
884 "param_reqs": param_reqs,
885 }
886 return info_dict
887
888 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100889 def evScaleNLargerMax(check=False, **kwargs):
890 error_name = ErrorIf.ScaleNLargerMax
891 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100892 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100893 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100894
895 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100896 scale = kwargs["scale"]
897
898 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
899 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100900
901 info_dict = {
902 "error_name": error_name,
903 "error_result": error_result,
904 "error_reason": error_reason,
905 "param_reqs": param_reqs,
906 }
907 return info_dict
908
909 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100910 def evScaleDLargerMax(check=False, **kwargs):
911 error_name = ErrorIf.ScaleDLargerMax
912 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100913 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100914 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100915
916 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100917 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100918
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100919 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
920 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100921 ):
922 error_result = True
923
924 info_dict = {
925 "error_name": error_name,
926 "error_result": error_result,
927 "error_reason": error_reason,
928 "param_reqs": param_reqs,
929 }
930 return info_dict
931
932 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100933 def evOffsetSmallerMin(check=False, **kwargs):
934 error_name = ErrorIf.OffsetSmallerMin
935 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100936 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100937 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100938
939 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100940 scale = kwargs["scale"]
941 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100942
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100943 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100944 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100945 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100946 error_result = True
947
948 info_dict = {
949 "error_name": error_name,
950 "error_result": error_result,
951 "error_reason": error_reason,
952 "param_reqs": param_reqs,
953 }
954 return info_dict
955
956 @staticmethod
957 def evOffsetLargerEqualMax(check=False, **kwargs):
958 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100959 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100960 error_result = False
961 error_reason = "Offset value larger than or equal to maximum value"
962
963 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100964 scale = kwargs["scale"]
965 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100966
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100967 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
968 error_result = True
969 elif (
970 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
971 ):
972 error_result = True
973
974 info_dict = {
975 "error_name": error_name,
976 "error_result": error_result,
977 "error_reason": error_reason,
978 "param_reqs": param_reqs,
979 }
980 return info_dict
981
982 @staticmethod
983 def evBorderSmallerMin(check=False, **kwargs):
984 error_name = ErrorIf.BorderSmallerMin
985 param_reqs = {"rank": None, "dtype": None, "shape": None}
986 error_result = False
987 error_reason = "Border value smaller than minimum value"
988
989 if check:
990 scale = kwargs["scale"]
991 border = kwargs["border"]
992
993 if (
994 scale[0] > 0
995 and scale[0] <= (1 << 11)
996 and (border[0] < (-16 * scale[0]))
997 ):
998 error_result = True
999 elif (
1000 scale[2] > 0
1001 and scale[2] <= (1 << 11)
1002 and (border[1] < (-16 * scale[2]))
1003 ):
1004 error_result = True
1005
1006 info_dict = {
1007 "error_name": error_name,
1008 "error_result": error_result,
1009 "error_reason": error_reason,
1010 "param_reqs": param_reqs,
1011 }
1012 return info_dict
1013
1014 @staticmethod
1015 def evBorderLargerEqualMax(check=False, **kwargs):
1016 error_name = ErrorIf.BorderLargerEqualMax
1017 param_reqs = {"rank": None, "dtype": None, "shape": None}
1018 error_result = False
1019 error_reason = "Border value larger than or equal to maximum value"
1020
1021 if check:
1022 scale = kwargs["scale"]
1023 border = kwargs["border"]
1024
1025 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
1026 error_result = True
1027 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
1028 error_result = True
1029
1030 info_dict = {
1031 "error_name": error_name,
1032 "error_result": error_result,
1033 "error_reason": error_reason,
1034 "param_reqs": param_reqs,
1035 }
1036 return info_dict
1037
1038 @staticmethod
1039 def checkResizeParams(scale, offset, border):
1040 return (
1041 min(scale) > 0
1042 and max(scale[0], scale[2]) <= (1 << 11)
1043 and scale[1] < 16 * scale[0]
1044 and scale[3] < 16 * scale[2]
1045 and offset[0] >= -scale[0]
1046 and offset[1] >= -scale[2]
1047 and offset[0] < 16 * scale[0]
1048 and offset[1] < 16 * scale[2]
1049 and border[0] >= -16 * scale[0]
1050 and border[1] >= -16 * scale[2]
1051 and border[0] < scale[0]
1052 and border[1] < scale[2]
1053 )
1054
    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        """Check that the provided RESIZE output H/W match the computed ones.

        Only flags an error when the scale/offset/border parameters are
        themselves valid and all dimensions are under the resize limit, so
        this error is not conflated with other parameter errors.
        """
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < gtu.MAX_RESIZE_DIMENSION
                and max(input_shape) < gtu.MAX_RESIZE_DIMENSION
            ):
                # Expected output size per axis:
                # ((in - 1) * scale_n - offset + border) // scale_d + 1
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                # Compare only the spatial dims (drop batch and channels).
                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1096
    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        """Check for RESIZE parameters that give non-integer output sizes.

        Flags the error when the output-size numerator is not an exact
        multiple of the scale denominator on either axis. Only evaluated
        when the scale/offset/border parameters are otherwise valid.
        """
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                # Remainder of ((in - 1) * scale_n - offset + border) / scale_d
                # must be zero for an exact integer output dimension.
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1131
1132 @staticmethod
1133 def evRankMismatch(check=False, **kwargs):
1134 error_name = ErrorIf.RankMismatch
1135 param_reqs = {"rank": None, "dtype": None, "shape": None}
1136 error_result = False
1137 error_reason = "Input Rank does not match output rank"
1138
1139 if check:
1140 input1_shape = kwargs["input1"].shape
Luke Huttona4e48ca2023-02-22 11:53:48 +00001141 input2_shape = (
1142 kwargs["input2"].shape if "input2" in kwargs else input1_shape
1143 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001144 # In case of SELECT op
1145 input3_shape = (
1146 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1147 )
Luke Hutton261b7b62023-01-10 14:50:31 +00001148
1149 for output in kwargs["result_tensors"]:
1150 output_shape = output.shape
1151 if (
1152 (len(input1_shape) != len(output_shape))
1153 or (len(input2_shape) != len(output_shape))
1154 or (len(input3_shape) != len(output_shape))
1155 ):
1156 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001157
1158 info_dict = {
1159 "error_name": error_name,
1160 "error_result": error_result,
1161 "error_reason": error_reason,
1162 "param_reqs": param_reqs,
1163 }
1164 return info_dict
1165
    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        """Check that output dimensions match the (broadcast) input dimensions.

        Shape ops compare input1 directly against the single output; all
        other ops compute the broadcast of the (up to three) input shapes
        and require every result tensor to match it exactly.
        """
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            op = kwargs["op"]
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
                # Shape ops are elementwise on shapes: no broadcasting.
                output_shape = kwargs["result_tensors"][0].shape
                if input1_shape != output_shape:
                    error_result = True

            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                # Broadcast all three inputs pairwise; None means the inputs
                # themselves were not broadcastable, so skip the check.
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1208
1209 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001210 def _getZeroPoint(qinfo, index):
1211 """Return zero point value from quantization info.
1212
1213 Generally input_zp is index 0, output_zp is index 1
1214 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001215 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001216
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """Check for a non-zero input zero point on a non-quantizable dtype.

        For MATMUL both operands carry their own zero point (qinfo indices
        0 and 1); for other ops only the first input is checked.
        """
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                # Either operand with a non-quantizable dtype and zp != 0
                # triggers the error.
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict
1255
1256 @staticmethod
1257 def evWeightZeroPointNotZero(check=False, **kwargs):
1258 op = kwargs["op"]
1259
1260 # exclude inputs with INT8 weights
1261 inputDtypes = [
1262 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1263 ]
1264
1265 error_name = ErrorIf.WeightZeroPointNotZero
1266 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1267 error_result = False
1268 error_reason = "Weight DType not INT8 and zero point not 0"
1269
1270 if check:
1271 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001272 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001273 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1274 error_result = True
1275
1276 info_dict = {
1277 "error_name": error_name,
1278 "error_result": error_result,
1279 "error_reason": error_reason,
1280 "param_reqs": param_reqs,
1281 }
1282 return info_dict
1283
    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        """Check for a non-zero output zero point on a non-quantizable dtype.

        AVG_POOL2D keys the check off the *input* dtype (its output zp must
        be 0 unless the input is INT8); other ops key off the output dtype.
        """
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            # Output zero point lives at qinfo index 1.
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1316
1317 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001318 def evU16InputZeroPointNotValid(check=False, **kwargs):
1319 error_name = ErrorIf.U16InputZeroPointNotValid
1320 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1321 error_result = False
1322 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1323
1324 if check:
1325 input_dtype = kwargs["input_dtype"]
1326 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1327 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1328 0,
1329 32768,
1330 ]
1331
1332 info_dict = {
1333 "error_name": error_name,
1334 "error_result": error_result,
1335 "error_reason": error_reason,
1336 "param_reqs": param_reqs,
1337 }
1338 return info_dict
1339
1340 @staticmethod
1341 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1342 error_name = ErrorIf.U16OutputZeroPointNotValid
1343 param_reqs = {"rank": None, "dtype": None, "shape": None}
1344 error_result = False
1345 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1346
1347 if check:
1348 output_dtype = kwargs["output_dtype"]
1349 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1350
1351 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1352 0,
1353 32768,
1354 ]
1355
1356 info_dict = {
1357 "error_name": error_name,
1358 "error_result": error_result,
1359 "error_reason": error_reason,
1360 "param_reqs": param_reqs,
1361 }
1362 return info_dict
1363
1364 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001365 def evAxisSmallerZero(check=False, **kwargs):
1366 error_name = ErrorIf.AxisSmallerZero
1367 param_reqs = {"rank": None, "dtype": None, "shape": None}
1368 error_result = False
1369 error_reason = "Axis smaller than zero"
1370
1371 if check:
1372 axis = kwargs["axis"]
1373 if axis < 0:
1374 error_result = True
1375
1376 info_dict = {
1377 "error_name": error_name,
1378 "error_result": error_result,
1379 "error_reason": error_reason,
1380 "param_reqs": param_reqs,
1381 }
1382 return info_dict
1383
1384 @staticmethod
1385 def evAxisLargerRank(check=False, **kwargs):
1386 error_name = ErrorIf.AxisLargerRank
1387 param_reqs = {"rank": None, "dtype": None, "shape": None}
1388 error_result = False
1389 error_reason = "Axis larger than rank"
1390
1391 if check:
1392 axis = kwargs["axis"]
1393 shape = kwargs["input_shape"]
1394 if axis > len(shape):
1395 error_result = True
1396
1397 info_dict = {
1398 "error_name": error_name,
1399 "error_result": error_result,
1400 "error_reason": error_reason,
1401 "param_reqs": param_reqs,
1402 }
1403 return info_dict
1404
1405 @staticmethod
1406 def evShapeOfAxisNotOne(check=False, **kwargs):
1407 error_name = ErrorIf.ShapeOfAxisNotOne
1408 param_reqs = {"rank": None, "dtype": None, "shape": None}
1409 error_result = False
1410 error_reason = "shape[axis] is not equal to 1"
1411
1412 if check:
1413 axis = kwargs["axis"]
1414 shape = kwargs["output_shape"]
1415 if (0 <= axis < len(shape)) and shape[axis] != 1:
1416 error_result = True
1417
1418 info_dict = {
1419 "error_name": error_name,
1420 "error_result": error_result,
1421 "error_reason": error_reason,
1422 "param_reqs": param_reqs,
1423 }
1424 return info_dict
1425
1426 @staticmethod
1427 def evPadSmallerZero(check=False, **kwargs):
1428 error_name = ErrorIf.PadSmallerZero
1429 param_reqs = {"rank": None, "dtype": None, "shape": None}
1430 error_result = False
1431 error_reason = "At least one pad is smaller than zero"
1432
1433 if check:
1434 op = kwargs["op"]
1435 pad = kwargs["pad"]
1436 if op["op"] == Op.PAD:
1437 for padding in pad:
1438 if min(padding) < 0:
1439 error_result = True
1440 else:
1441 if min(pad) < 0:
1442 error_result = True
1443
1444 info_dict = {
1445 "error_name": error_name,
1446 "error_result": error_result,
1447 "error_reason": error_reason,
1448 "param_reqs": param_reqs,
1449 }
1450 return info_dict
1451
    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        """Check for padding values too large for the kernel dimensions.

        TRANSPOSE_CONV2D: negative out_pad values must stay above -kernel
        size (kernel taken from the weight H/W). Pooling ops: each pad must
        be smaller than the kernel dimension on its axis.
        """
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d: kernel H/W are the middle dims of the
                # (O, H, W, I) weight shape
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op; pad order is (top, bottom, left, right) and
                # kernel is (y, x). Zero pads or a 1x1 kernel never trigger
                # this error.
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1491
1492 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001493 def evPadOutputShapeMismatch(check=False, **kwargs):
1494 error_name = ErrorIf.PadOutputShapeMismatch
1495 param_reqs = {"rank": None, "dtype": None, "shape": None}
1496 error_result = False
1497 error_reason = "Pad output shape mismatch for requested padding"
1498
1499 if check:
1500 pad = kwargs["pad"]
1501 input_shape = kwargs["input_shape"]
1502 output_shape = kwargs["output_shape"]
1503 for dim, padding in enumerate(pad):
1504 expected_size = input_shape[dim] + padding[0] + padding[1]
1505 if expected_size != output_shape[dim]:
1506 error_result = True
1507
1508 info_dict = {
1509 "error_name": error_name,
1510 "error_result": error_result,
1511 "error_reason": error_reason,
1512 "param_reqs": param_reqs,
1513 }
1514 return info_dict
1515
1516 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001517 def checkPoolingParams(kernel, stride, pad):
1518 return (
1519 min(kernel) >= 1
1520 and min(stride) >= 1
1521 and min(pad) >= 0
1522 and not (
1523 pad[0] >= kernel[0]
1524 or pad[1] >= kernel[0]
1525 or pad[2] >= kernel[1]
1526 or pad[3] >= kernel[1]
1527 )
1528 )
1529
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """Check that the provided pooling output H/W match the computed ones.

        Only flags an error when the kernel/stride/pad combination is valid,
        so this error is not conflated with other parameter errors.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            # pad order: (top, bottom, left, right)
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            # (guarded so a zero stride cannot divide by zero; in that case
            # params_valid below is False and y/x_correct are never read)
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1573
    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        """Check for pooling parameters that give non-integer output sizes.

        Flags the error when (in + pads - kernel) is not an exact multiple
        of the stride on either axis. Only evaluated when the kernel/stride/
        pad combination is otherwise valid.
        """
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            # pad order: (top, bottom, left, right)
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            # (guarded so a zero stride cannot divide by zero; in that case
            # params_valid below is False and the remainders are never read)
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1611
    @staticmethod
    def checkConvParams(op, weight_shape, stride, pad, dilation):
        """Return True when the convolution parameters are all legal.

        TRANSPOSE_CONV2D out_pad values may be negative but must stay above
        the negated kernel size (weight H/W); other convs need non-negative
        padding. Pass dilation=None for ops without a dilation attribute.
        """
        if op == Op.TRANSPOSE_CONV2D:
            pad_ok = (
                pad[0] > -weight_shape[1]
                and pad[1] > -weight_shape[1]
                and pad[2] > -weight_shape[2]
                and pad[3] > -weight_shape[2]
            )
        else:
            pad_ok = min(pad) >= 0

        return (
            # Check kernel sizes
            min(weight_shape[1:-1]) >= 1
            and min(stride) >= 1
            and pad_ok
            and (dilation is None or min(dilation) >= 1)
        )
1631
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """Check that provided conv output spatial dims match computed ones.

        TRANSPOSE_CONV2D uses the upsampling formula
        (in - 1) * stride + out_pads + kernel; other convs use the standard
        (in - 1 + pads - (kernel - 1) * dilation) // stride + 1. Only flags
        an error when the conv parameters are themselves valid.
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights are (H, W, I, M): no leading output
            # channel dim, so the kernel dims start at index 0.
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad holds (before, after) per spatial dim, flattened.
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            # Compare only the spatial dims (drop batch and channels).
            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1693
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        """Validate that convolution parameters divide exactly, i.e. that
        the output dimensions would be whole integers.

        Expected kwargs: op, pad, weight_shape, input_shape, dilation,
        stride. Returns the standard info_dict.
        """
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D kernel spatial dims start at weight index 0;
            # the other convs carry them starting at index 1
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            # (remainder of the output-size division per spatial axis; a
            # non-zero remainder means the output is not an exact integer)
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad is laid out as consecutive (before, after) pairs
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1742
1743 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001744 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1745 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1746 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1747 error_result = False
1748 error_reason = (
1749 "Mismatch between output shape provided and expected output shape"
1750 )
1751
1752 if check:
1753 output_shape = kwargs["output_shape"]
1754 input_shape = kwargs["input_shape"]
1755 axis = kwargs["axis"]
1756
1757 dimension_match = True
1758 axis_shift = 0
1759
1760 # Check that rank is correct before trying to check dimensions
1761 if (len(input_shape) - 1) == len(output_shape):
1762 for i in range(len(input_shape)):
1763 if i == axis:
1764 axis_shift = 1
1765 continue
1766 if input_shape[i] != output_shape[i - axis_shift]:
1767 dimension_match = False
1768
1769 if not dimension_match:
1770 error_result = True
1771
1772 info_dict = {
1773 "error_name": error_name,
1774 "error_result": error_result,
1775 "error_reason": error_reason,
1776 "param_reqs": param_reqs,
1777 }
1778 return info_dict
1779
1780 @staticmethod
1781 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1782 error_name = ErrorIf.ArgmaxOutputRankMismatch
1783 param_reqs = {"rank": None, "dtype": None, "shape": None}
1784 error_result = False
1785 error_reason = (
1786 "Mismatch between output shape provided and expected output shape"
1787 )
1788
1789 if check:
1790 output_shape = kwargs["output_shape"]
1791 input_shape = kwargs["input_shape"]
1792 axis = kwargs["axis"]
1793 valid_params = axis >= 0 and axis < len(input_shape)
1794
1795 if valid_params and (len(input_shape) - 1) != len(output_shape):
1796 error_result = True
1797
1798 info_dict = {
1799 "error_name": error_name,
1800 "error_result": error_result,
1801 "error_reason": error_reason,
1802 "param_reqs": param_reqs,
1803 }
1804 return info_dict
1805
1806 @staticmethod
1807 def evKernelSmallerOne(check=False, **kwargs):
1808 error_name = ErrorIf.KernelSmallerOne
1809 param_reqs = {"rank": None, "dtype": None, "shape": None}
1810 error_result = False
1811 error_reason = "At least one kernel dimension is smaller than zero"
1812
1813 if check:
1814 kernel = kwargs["kernel"]
1815 if min(kernel) < 1:
1816 error_result = True
1817
1818 info_dict = {
1819 "error_name": error_name,
1820 "error_result": error_result,
1821 "error_reason": error_reason,
1822 "param_reqs": param_reqs,
1823 }
1824 return info_dict
1825
1826 @staticmethod
1827 def evStrideSmallerOne(check=False, **kwargs):
1828 error_name = ErrorIf.StrideSmallerOne
1829 param_reqs = {"rank": None, "dtype": None, "shape": None}
1830 error_result = False
1831 error_reason = "At least one stride dimension is smaller than zero"
1832
1833 if check:
1834 stride = kwargs["stride"]
1835 if min(stride) < 1:
1836 error_result = True
1837
1838 info_dict = {
1839 "error_name": error_name,
1840 "error_result": error_result,
1841 "error_reason": error_reason,
1842 "param_reqs": param_reqs,
1843 }
1844 return info_dict
1845
1846 @staticmethod
1847 def evDilationSmallerOne(check=False, **kwargs):
1848 error_result = check and min(kwargs["dilation"]) < 1
1849 return {
1850 "error_name": ErrorIf.DilationSmallerOne,
1851 "error_reason": "At least one dilation is smaller than one",
1852 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1853 "error_result": error_result,
1854 }
1855
1856 @staticmethod
1857 def evScaleTrue(check=False, **kwargs):
1858 error_name = ErrorIf.ScaleTrue
1859 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1860 error_result = False
1861 error_reason = "Scale set to true but input type is INT48"
1862
1863 if check:
1864 input_dtype = kwargs["input_dtype"]
1865 scale32 = kwargs["scale32"]
1866 if scale32 and input_dtype == DType.INT48:
1867 error_result = True
1868
1869 info_dict = {
1870 "error_name": error_name,
1871 "error_result": error_result,
1872 "error_reason": error_reason,
1873 "param_reqs": param_reqs,
1874 }
1875 return info_dict
1876
1877 @staticmethod
1878 def evScaleNotTrue(check=False, **kwargs):
1879 error_name = ErrorIf.ScaleNotTrue
1880 param_reqs = {"rank": None, "dtype": None, "shape": None}
1881 error_result = False
1882 error_reason = "Scale set to false but double round set to true"
1883
1884 if check:
1885 scale32 = kwargs["scale32"]
1886 double_round = kwargs["double_round"]
1887 if not scale32 and double_round:
1888 error_result = True
1889
1890 info_dict = {
1891 "error_name": error_name,
1892 "error_result": error_result,
1893 "error_reason": error_reason,
1894 "param_reqs": param_reqs,
1895 }
1896 return info_dict
1897
1898 @staticmethod
1899 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1900 error_name = ErrorIf.TensorSizeInputOutputMismatch
1901 param_reqs = {"rank": None, "dtype": None, "shape": None}
1902 error_result = False
1903 error_reason = "Input tensor size does not match output tensor size"
Jerry Ge264f7fa2023-04-21 22:49:57 +00001904 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001905
1906 if check:
1907 input_shape = kwargs["input_shape"]
1908 output_shape = kwargs["output_shape"]
Jerry Ge264f7fa2023-04-21 22:49:57 +00001909 shape_inferencing = False
1910 if -1 in output_shape and op["op"] == Op.RESHAPE:
1911 shape_inferencing = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001912 input_size = np.prod(input_shape)
1913 output_size = np.prod(output_shape)
Jerry Ge264f7fa2023-04-21 22:49:57 +00001914 if input_size != output_size and not shape_inferencing:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001915 error_result = True
1916
1917 info_dict = {
1918 "error_name": error_name,
1919 "error_result": error_result,
1920 "error_reason": error_reason,
1921 "param_reqs": param_reqs,
1922 }
1923 return info_dict
1924
1925 @staticmethod
1926 def evStartSmallerZero(check=False, **kwargs):
1927 error_name = ErrorIf.StartSmallerZero
1928 param_reqs = {"rank": None, "dtype": None, "shape": None}
1929 error_result = False
1930 error_reason = "Starting point smaller than zero"
1931
1932 if check:
1933 input_shape = kwargs["input_shape"]
1934 start = kwargs["start"]
1935 rank = len(input_shape)
1936 if len(start) == rank:
1937 for index in range(rank):
1938 if start[index] < 0:
1939 error_result = True
1940
1941 info_dict = {
1942 "error_name": error_name,
1943 "error_result": error_result,
1944 "error_reason": error_reason,
1945 "param_reqs": param_reqs,
1946 }
1947 return info_dict
1948
1949 @staticmethod
1950 def evSizeSmallerEqualZero(check=False, **kwargs):
1951 error_name = ErrorIf.SizeSmallerEqualZero
1952 param_reqs = {"rank": None, "dtype": None, "shape": None}
1953 error_result = False
1954 error_reason = "Size smaller than or equal to zero"
1955
1956 if check:
1957 input_shape = kwargs["input_shape"]
1958 size = kwargs["size"]
1959 rank = len(input_shape)
1960 if len(size) == rank:
1961 for index in range(rank):
1962 if size[index] <= 0:
1963 error_result = True
1964
1965 info_dict = {
1966 "error_name": error_name,
1967 "error_result": error_result,
1968 "error_reason": error_reason,
1969 "param_reqs": param_reqs,
1970 }
1971 return info_dict
1972
1973 @staticmethod
1974 def evStartSizeOutsideBounds(check=False, **kwargs):
1975 error_name = ErrorIf.StartSizeOutsideBounds
1976 param_reqs = {"rank": None, "dtype": None, "shape": None}
1977 error_result = False
1978 error_reason = "starting point plus size larger than input dimension"
1979
1980 if check:
1981 input_shape = kwargs["input_shape"]
1982 start = kwargs["start"]
1983 size = kwargs["size"]
1984 rank = len(input_shape)
1985 if len(start) == rank and len(size) == rank:
1986 for index in range(rank):
1987 if start[index] + size[index] > input_shape[index]:
1988 error_result = True
1989
1990 info_dict = {
1991 "error_name": error_name,
1992 "error_result": error_result,
1993 "error_reason": error_reason,
1994 "param_reqs": param_reqs,
1995 }
1996 return info_dict
1997
1998 @staticmethod
1999 def evSizeOutputShapeMismatch(check=False, **kwargs):
2000 error_name = ErrorIf.SizeOutputShapeMismatch
2001 param_reqs = {"rank": None, "dtype": None, "shape": None}
2002 error_result = False
2003 error_reason = "Size does not match output dimension"
2004
2005 if check:
2006 input_shape = kwargs["input_shape"]
2007 output_shape = kwargs["output_shape"]
2008 size = kwargs["size"]
Luke Huttona4e48ca2023-02-22 11:53:48 +00002009
2010 if len(input_shape) == len(output_shape):
2011 rank = len(input_shape)
2012 if len(size) == rank:
2013 for index in range(rank):
2014 if size[index] != output_shape[index]:
2015 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002016
2017 info_dict = {
2018 "error_name": error_name,
2019 "error_result": error_result,
2020 "error_reason": error_reason,
2021 "param_reqs": param_reqs,
2022 }
2023 return info_dict
2024
2025 @staticmethod
2026 def evInputSizeStartLengthMismatch(check=False, **kwargs):
2027 error_name = ErrorIf.InputSizeStartLengthMismatch
2028 param_reqs = {"rank": None, "dtype": None, "shape": None}
2029 error_result = False
2030 error_reason = "rank of input not equal to length of start or size"
2031
2032 if check:
2033 input_shape = kwargs["input_shape"]
2034 start = kwargs["start"]
2035 size = kwargs["size"]
2036 rank = len(input_shape)
2037 if rank != len(start) or rank != len(size):
2038 error_result = True
2039
2040 info_dict = {
2041 "error_name": error_name,
2042 "error_result": error_result,
2043 "error_reason": error_reason,
2044 "param_reqs": param_reqs,
2045 }
2046 return info_dict
2047
2048 @staticmethod
2049 def evIndexOutsideBounds(check=False, **kwargs):
2050 error_name = ErrorIf.IndexOutsideBounds
2051 param_reqs = {"rank": None, "dtype": None, "shape": None}
2052 error_result = False
2053 error_reason = "Index outside of allowed bounds"
2054
2055 if check:
2056 input_shape = kwargs["input_shape"]
2057 perms = kwargs["perms"]
2058 rank = len(input_shape)
2059
2060 for index in perms:
2061 if index < 0 or index > rank:
2062 error_result = True
2063
2064 info_dict = {
2065 "error_name": error_name,
2066 "error_result": error_result,
2067 "error_reason": error_reason,
2068 "param_reqs": param_reqs,
2069 }
2070 return info_dict
2071
2072 @staticmethod
2073 def evIndexUsedTwice(check=False, **kwargs):
2074 error_name = ErrorIf.IndexUsedTwice
2075 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2076 error_result = False
2077 error_reason = "Index used multiple times"
2078
2079 if check:
2080 perms = kwargs["perms"]
2081
2082 unique_indices = []
2083 for index in perms:
2084 if index in unique_indices:
2085 error_result = True
2086 else:
2087 unique_indices.append(index)
2088
2089 info_dict = {
2090 "error_name": error_name,
2091 "error_result": error_result,
2092 "error_reason": error_reason,
2093 "param_reqs": param_reqs,
2094 }
2095 return info_dict
2096
2097 @staticmethod
2098 def evMaxSmallerMin(check=False, **kwargs):
2099 error_name = ErrorIf.MaxSmallerMin
2100 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2101 error_result = False
2102 error_reason = "Max value smaller than min value"
2103
2104 if check:
2105 max_val = kwargs["max_val"]
2106 min_val = kwargs["min_val"]
2107 if max_val < min_val:
2108 error_result = True
2109
2110 info_dict = {
2111 "error_name": error_name,
2112 "error_result": error_result,
2113 "error_reason": error_reason,
2114 "param_reqs": param_reqs,
2115 }
2116 return info_dict
2117
2118 @staticmethod
2119 def evConcatInputRankMismatch(check=False, **kwargs):
2120 error_name = ErrorIf.ConcatInputRankMismatch
2121 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2122 error_result = False
2123 error_reason = "Input ranks are not identical"
2124
2125 if check:
2126 inputs = kwargs["inputs"]
2127 input_shape = kwargs["input_shape"]
2128 for input in inputs:
2129 if len(input.shape) != len(input_shape):
2130 error_result = True
2131
2132 info_dict = {
2133 "error_name": error_name,
2134 "error_result": error_result,
2135 "error_reason": error_reason,
2136 "param_reqs": param_reqs,
2137 }
2138 return info_dict
2139
2140 @staticmethod
2141 def evConcatInputDimMismatch(check=False, **kwargs):
2142 error_name = ErrorIf.ConcatInputDimMismatch
2143 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2144 error_result = False
2145 error_reason = "Input dimensions differ on too many axes"
2146
2147 if check:
2148 inputs = kwargs["inputs"]
2149 input_shape = kwargs["input_shape"]
2150 axis = kwargs["axis"]
2151
2152 # Ensure rank is valid before checking dims.
2153 valid_rank = True
2154 for input in inputs:
2155 if len(input.shape) != len(input_shape):
2156 valid_rank = False
2157
2158 if valid_rank:
2159 for input in inputs:
2160 for i, dim in enumerate(input.shape):
2161 if dim != input_shape[i] and axis != i:
2162 error_result = True
2163
2164 info_dict = {
2165 "error_name": error_name,
2166 "error_result": error_result,
2167 "error_reason": error_reason,
2168 "param_reqs": param_reqs,
2169 }
2170 return info_dict
2171
2172 @staticmethod
2173 def evConcatShapeSumMismatch(check=False, **kwargs):
2174 error_name = ErrorIf.ConcatShapeSumMismatch
2175 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2176 error_result = False
2177 error_reason = "Sum of dimensions on axis not equal to output dimension"
2178
2179 if check:
2180 inputs = kwargs["inputs"]
2181 input_shape = kwargs["input_shape"]
2182 output_shape = kwargs["output_shape"]
2183 axis = kwargs["axis"]
2184
2185 # Ensure rank is valid before checking dims.
2186 valid_params = True
2187 for input in inputs:
2188 if len(input.shape) != len(input_shape):
2189 valid_params = False
2190 if axis < 0 or axis > len(input_shape):
2191 valid_params = False
2192
2193 if valid_params:
2194 axis_dim_sum = 0
2195 for input in inputs:
2196 axis_dim_sum += input.shape[axis]
2197
2198 if axis_dim_sum != output_shape[axis]:
2199 error_result = True
2200
2201 info_dict = {
2202 "error_name": error_name,
2203 "error_result": error_result,
2204 "error_reason": error_reason,
2205 "param_reqs": param_reqs,
2206 }
2207 return info_dict
2208
2209 @staticmethod
2210 def evInputListThenGraphMismatch(check=False, **kwargs):
2211 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2212 param_reqs = {"rank": None, "dtype": None, "shape": None}
2213 error_result = False
2214 error_reason = "Input list shape does not match then-graph shape"
2215
2216 if check:
2217 a = kwargs["a"]
2218 b = kwargs["b"]
2219 basicBlocks = kwargs["basicBlocks"]
2220 then_block = basicBlocks[1]
2221 then_inputs = then_block.inputs
2222 then_tens = then_block.tensors
2223 if (a.shape != then_tens[then_inputs[0]].shape) or (
2224 b.shape != then_tens[then_inputs[1]].shape
2225 ):
2226 error_result = True
2227
2228 info_dict = {
2229 "error_name": error_name,
2230 "error_result": error_result,
2231 "error_reason": error_reason,
2232 "param_reqs": param_reqs,
2233 }
2234 return info_dict
2235
2236 @staticmethod
2237 def evInputListElseGraphMismatch(check=False, **kwargs):
2238 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2239 param_reqs = {"rank": None, "dtype": None, "shape": None}
2240 error_result = False
2241 error_reason = "Input list shape does not match else-graph shape"
2242
2243 if check:
2244 a = kwargs["a"]
2245 b = kwargs["b"]
2246 basicBlocks = kwargs["basicBlocks"]
2247 else_block = basicBlocks[2]
2248 else_inputs = else_block.inputs
2249 else_tens = else_block.tensors
2250 if (a.shape != else_tens[else_inputs[0]].shape) or (
2251 b.shape != else_tens[else_inputs[1]].shape
2252 ):
2253 error_result = True
2254
2255 info_dict = {
2256 "error_name": error_name,
2257 "error_result": error_result,
2258 "error_reason": error_reason,
2259 "param_reqs": param_reqs,
2260 }
2261 return info_dict
2262
2263 @staticmethod
2264 def evOutputListThenGraphMismatch(check=False, **kwargs):
2265 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2266 param_reqs = {"rank": None, "dtype": None, "shape": None}
2267 error_result = False
2268 error_reason = "Output list shape does not match then-graph shape"
2269
2270 if check:
2271 basicBlocks = kwargs["basicBlocks"]
2272 cond_block = basicBlocks[0]
2273 cond_outputs = cond_block.outputs
2274 cond_tens = cond_block.tensors
2275 then_block = basicBlocks[1]
2276 then_outputs = then_block.outputs
2277 then_tens = then_block.tensors
2278 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2279 error_result = True
2280
2281 info_dict = {
2282 "error_name": error_name,
2283 "error_result": error_result,
2284 "error_reason": error_reason,
2285 "param_reqs": param_reqs,
2286 }
2287 return info_dict
2288
2289 @staticmethod
2290 def evOutputListElseGraphMismatch(check=False, **kwargs):
2291 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2292 param_reqs = {"rank": None, "dtype": None, "shape": None}
2293 error_result = False
2294 error_reason = "Output list shape does not match else-graph shape"
2295
2296 if check:
2297 basicBlocks = kwargs["basicBlocks"]
2298 cond_block = basicBlocks[0]
2299 cond_outputs = cond_block.outputs
2300 cond_tens = cond_block.tensors
2301 else_block = basicBlocks[2]
2302 else_outputs = else_block.outputs
2303 else_tens = else_block.tensors
2304 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2305 error_result = True
2306
2307 info_dict = {
2308 "error_name": error_name,
2309 "error_result": error_result,
2310 "error_reason": error_reason,
2311 "param_reqs": param_reqs,
2312 }
2313 return info_dict
2314
2315 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002316 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2317 error_name = ErrorIf.CondIfCondNotMatchingBool
2318 param_reqs = {"rank": None, "dtype": None, "shape": None}
2319 error_result = False
2320 error_reason = "Conditional tensor does not match bool type"
2321
2322 if check:
2323 cond = kwargs["cond"]
2324 if cond.dtype != DType.BOOL:
2325 error_result = True
2326
2327 info_dict = {
2328 "error_name": error_name,
2329 "error_result": error_result,
2330 "error_reason": error_reason,
2331 "param_reqs": param_reqs,
2332 }
2333 return info_dict
2334
2335 @staticmethod
2336 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2337 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2338 param_reqs = {"rank": None, "dtype": None, "shape": None}
2339 error_result = False
2340 error_reason = "Conditional tensor is not equal to a size of one"
2341
2342 if check:
2343 cond = kwargs["cond"]
2344 # Size of 1 is equivalent to rank 0
2345 if len(cond.shape) != 0:
2346 error_result = True
2347
2348 info_dict = {
2349 "error_name": error_name,
2350 "error_result": error_result,
2351 "error_reason": error_reason,
2352 "param_reqs": param_reqs,
2353 }
2354 return info_dict
2355
2356 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002357 def evInputListOutputListMismatch(check=False, **kwargs):
2358 error_name = ErrorIf.InputListOutputListMismatch
2359 param_reqs = {"rank": None, "dtype": None, "shape": None}
2360 error_result = False
2361 error_reason = "Input list does not match output list"
2362
2363 if check:
2364 basicBlocks = kwargs["basicBlocks"]
2365 while_block = basicBlocks[0]
2366 while_inputs = while_block.inputs
2367 while_outputs = while_block.outputs
2368 while_tens = while_block.tensors
2369 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2370 error_result = True
2371
2372 info_dict = {
2373 "error_name": error_name,
2374 "error_result": error_result,
2375 "error_reason": error_reason,
2376 "param_reqs": param_reqs,
2377 }
2378 return info_dict
2379
2380 @staticmethod
2381 def evInputListCondGraphMismatch(check=False, **kwargs):
2382 error_name = ErrorIf.InputListCondGraphMismatch
2383 param_reqs = {"rank": None, "dtype": None, "shape": None}
2384 error_result = False
2385 error_reason = "Input list does not match cond graph"
2386
2387 if check:
2388 basicBlocks = kwargs["basicBlocks"]
2389 while_block = basicBlocks[0]
2390 while_inputs = while_block.inputs
2391 while_tens = while_block.tensors
2392 cond_block = basicBlocks[1]
2393 cond_inputs = cond_block.inputs
2394 cond_tens = cond_block.tensors
2395 if (
2396 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2397 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2398 error_result = True
2399
2400 info_dict = {
2401 "error_name": error_name,
2402 "error_result": error_result,
2403 "error_reason": error_reason,
2404 "param_reqs": param_reqs,
2405 }
2406 return info_dict
2407
2408 @staticmethod
2409 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2410 error_name = ErrorIf.InputListBodyGraphInputMismatch
2411 param_reqs = {"rank": None, "dtype": None, "shape": None}
2412 error_result = False
2413 error_reason = "Input list does not match body graph input"
2414
2415 if check:
2416 basicBlocks = kwargs["basicBlocks"]
2417 while_block = basicBlocks[0]
2418 while_inputs = while_block.inputs
2419 while_tens = while_block.tensors
2420 body_block = basicBlocks[2]
2421 body_outputs = body_block.inputs
2422 body_tens = body_block.tensors
2423 if (
2424 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2425 ) or (
2426 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2427 ):
2428 error_result = True
2429
2430 info_dict = {
2431 "error_name": error_name,
2432 "error_result": error_result,
2433 "error_reason": error_reason,
2434 "param_reqs": param_reqs,
2435 }
2436 return info_dict
2437
2438 @staticmethod
2439 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2440 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2441 param_reqs = {"rank": None, "dtype": None, "shape": None}
2442 error_result = False
2443 error_reason = "Input list does not match body graph output"
2444
2445 if check:
2446 basicBlocks = kwargs["basicBlocks"]
2447 while_block = basicBlocks[0]
2448 while_inputs = while_block.inputs
2449 while_tens = while_block.tensors
2450 body_block = basicBlocks[2]
2451 body_outputs = body_block.outputs
2452 body_tens = body_block.tensors
2453 if (
2454 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2455 ) or (
2456 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2457 ):
2458 error_result = True
2459 info_dict = {
2460 "error_name": error_name,
2461 "error_result": error_result,
2462 "error_reason": error_reason,
2463 "param_reqs": param_reqs,
2464 }
2465 return info_dict
2466
2467 @staticmethod
2468 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2469 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2470 param_reqs = {"rank": None, "dtype": None, "shape": None}
2471 error_result = False
2472 error_reason = "Cond graph output is not a match list of booleans"
2473
2474 if check:
2475 basicBlocks = kwargs["basicBlocks"]
2476 cond_block = basicBlocks[1]
2477 cond_outputs = cond_block.outputs
2478 cond_tens = cond_block.tensors
2479 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2480 error_result = True
2481
2482 info_dict = {
2483 "error_name": error_name,
2484 "error_result": error_result,
2485 "error_reason": error_reason,
2486 "param_reqs": param_reqs,
2487 }
2488 return info_dict
2489
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002490 @staticmethod
2491 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2492 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2493 param_reqs = {"rank": None, "dtype": None, "shape": None}
2494 error_result = False
2495 error_reason = "Cond graph output is not a shape of size one"
2496
2497 if check:
2498 basicBlocks = kwargs["basicBlocks"]
2499 cond_block = basicBlocks[1]
2500 cond_outputs = cond_block.outputs
2501 cond_tens = cond_block.tensors
2502 # Size of 1 is equivalent to rank 0
2503 if len(cond_tens[cond_outputs[0]].shape) != 0:
2504 error_result = True
2505
2506 info_dict = {
2507 "error_name": error_name,
2508 "error_result": error_result,
2509 "error_reason": error_reason,
2510 "param_reqs": param_reqs,
2511 }
2512 return info_dict
2513
Luke Hutton261b7b62023-01-10 14:50:31 +00002514 @staticmethod
2515 def evKernelNotPowerOfTwo(check=False, **kwargs):
2516 error_name = ErrorIf.KernelNotPowerOfTwo
2517 param_reqs = {"rank": None, "dtype": None, "shape": None}
2518 error_result = False
2519 error_reason = "kernel height and/or width not a power of two"
2520
2521 def is_power_of_two(x):
2522 return math.log(x, 2).is_integer()
2523
2524 if check:
2525 shape = kwargs["input_shape"]
2526 if len(shape) == 3:
2527 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2528 error_result = not valid_kernel
2529
2530 info_dict = {
2531 "error_name": error_name,
2532 "error_result": error_result,
2533 "error_reason": error_reason,
2534 "param_reqs": param_reqs,
2535 }
2536 return info_dict
2537
Luke Hutton57287132023-02-06 14:54:18 +00002538 @staticmethod
2539 def evFFTInputShapeMismatch(check=False, **kwargs):
2540 error_name = ErrorIf.FFTInputShapeMismatch
2541 param_reqs = {"rank": None, "dtype": None, "shape": None}
2542 error_result = False
2543 error_reason = "Mismatch between real and imaginary input shapes"
2544
2545 if check:
2546 input1 = kwargs["input1"]
2547 input2 = kwargs["input2"]
2548
2549 if input1.shape != input2.shape:
2550 error_result = True
2551
2552 info_dict = {
2553 "error_name": error_name,
2554 "error_result": error_result,
2555 "error_reason": error_reason,
2556 "param_reqs": param_reqs,
2557 }
2558 return info_dict
2559
2560 @staticmethod
2561 def evFFTOutputShapeMismatch(check=False, **kwargs):
2562 error_name = ErrorIf.FFTOutputShapeMismatch
2563 param_reqs = {"rank": None, "dtype": None, "shape": None}
2564 error_result = False
2565 error_reason = (
2566 "Mismatch between provided and expected output kernel (H, W) shape"
2567 )
2568
2569 if check:
2570 op = kwargs["op"]
2571 input_shape = kwargs["input_shape"]
2572
2573 if len(input_shape) == 3:
2574 output_shapes = kwargs["output_shape"]
2575
2576 # Ignoring batch size (N) from input shape
2577 expected_shape = input_shape[1:]
2578 if op["op"] == Op.RFFT2D:
2579 expected_shape[1] = expected_shape[1] // 2 + 1
2580
2581 # Ignoring batch size (N) from output shapes
2582 output_shape_0 = output_shapes[0][1:]
2583 output_shape_1 = output_shapes[1][1:]
2584 # Ensure sure the kernel sizes (H, W) of both outputs match the expected
2585 if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
2586 error_result = True
2587
2588 info_dict = {
2589 "error_name": error_name,
2590 "error_result": error_result,
2591 "error_reason": error_reason,
2592 "param_reqs": param_reqs,
2593 }
2594 return info_dict
2595
Jerry Ge264f7fa2023-04-21 22:49:57 +00002596 @staticmethod
Jerry Ge135c9552023-05-23 20:59:32 +00002597 def calculateBroadcastShape(input_shape_a, input_shape_b):
2598 if input_shape_a is not None and input_shape_b is not None:
2599 calculated_shape = input_shape_a.copy()
2600 for idx in range(len(calculated_shape)):
2601 if calculated_shape[idx] == 1:
2602 calculated_shape[idx] = input_shape_b[idx]
2603 elif (
2604 input_shape_b[idx] != 1
2605 and input_shape_b[idx] != calculated_shape[idx]
2606 ):
2607 return None
2608 return calculated_shape
2609 else:
2610 return None
2611
2612 @staticmethod
2613 def evBroadcastShapesMismatch(check=False, **kwargs):
2614 error_name = ErrorIf.BroadcastShapesMismatch
2615 param_reqs = {"rank": None, "dtype": None, "shape": None}
2616 error_result = False
2617 error_reason = "Broadcast shape calculating failed"
2618
2619 if check:
2620 input_shape_a = kwargs["input1"].shape
2621 input_shape_b = kwargs["input2"].shape
2622 input_shape_c = (
2623 kwargs["input3"].shape if "input3" in kwargs else input_shape_b
2624 )
2625
2626 if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
2627 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
2628 input_shape_c,
2629 TosaErrorValidator.calculateBroadcastShape(
2630 input_shape_a, input_shape_b
2631 ),
2632 )
2633 error_result = calculated_shape is None
2634
2635 info_dict = {
2636 "error_name": error_name,
2637 "error_result": error_result,
2638 "error_reason": error_reason,
2639 "param_reqs": param_reqs,
2640 }
2641 return info_dict
2642
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002643 def evWrongAccumulatorType(check=False, **kwargs):
2644 error_name = ErrorIf.WrongAccumulatorType
2645 param_reqs = {"rank": None, "dtype": None, "shape": None}
2646 error_result = False
2647 error_reason = "An unsupported accumulator data type was requested"
2648
2649 if check:
2650 op = kwargs["op"]
2651 input_dtype = kwargs["input_dtype"]
2652 accum_dtype = kwargs["accum_dtype"]
2653 if op["op"] == Op.AVG_POOL2D:
2654 if (
2655 input_dtype
2656 in (
2657 DType.INT8,
2658 DType.INT16,
2659 )
2660 and accum_dtype != DType.INT32
2661 ):
2662 error_result = True
2663 elif (
2664 input_dtype
2665 in (
2666 DType.FP32,
2667 DType.BF16,
2668 )
2669 and accum_dtype != DType.FP32
2670 ):
2671 error_result = True
2672 elif input_dtype == DType.FP16 and accum_dtype not in (
2673 DType.FP16,
2674 DType.FP32,
2675 ):
2676 error_result = True
Won Jeon2c34b462024-02-06 18:37:00 +00002677 elif (
2678 input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
2679 and accum_dtype != DType.FP16
2680 ):
2681 error_result = True
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002682
Tai Lyf36f2562024-03-14 16:21:29 +00002683 elif op["op"] in {
2684 Op.CONV2D,
2685 Op.CONV3D,
2686 Op.DEPTHWISE_CONV2D,
2687 Op.TRANSPOSE_CONV2D,
2688 }:
2689 if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
2690 error_result = True
2691 elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
2692 error_result = True
2693 elif (
2694 input_dtype
2695 in (
2696 DType.FP32,
2697 DType.BF16,
2698 )
2699 and accum_dtype != DType.FP32
2700 ):
2701 error_result = True
2702 elif input_dtype == DType.FP16 and accum_dtype not in (
2703 DType.FP16,
2704 DType.FP32,
2705 ):
2706 error_result = True
2707 elif (
2708 input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
2709 and accum_dtype != DType.FP16
2710 ):
2711 error_result = True
2712
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002713 info_dict = {
2714 "error_name": error_name,
2715 "error_result": error_result,
2716 "error_reason": error_reason,
2717 "param_reqs": param_reqs,
2718 }
2719 return info_dict
2720
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002721
class TosaInvalidValidator:
    """Predicates that decide whether a generated test case is invalid
    (impossible to run) and should therefore be skipped."""

    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        """Return True when the resize mode / dtype combination is invalid."""
        input_dtype = kwargs["input_dtype"]
        arg_info = kwargs["args"]
        mode = arg_info["mode"]
        output_dtype = arg_info["output_dtype"]

        if mode == ResizeMode.BILINEAR:
            # Only these (input, output) dtype pairings are legal
            legal_pairs = {
                (DType.INT8, DType.INT32),
                (DType.INT16, DType.INT48),
                (DType.FP16, DType.FP16),
                (DType.BF16, DType.BF16),
                (DType.FP32, DType.FP32),
            }
            return (input_dtype, output_dtype) not in legal_pairs

        if mode == ResizeMode.NEAREST:
            # Dtypes must match and be drawn from the legal set
            legal_types = (
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            )
            return input_dtype != output_dtype or input_dtype not in legal_types

        # Any other mode value is itself invalid
        return True

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when the op's output height or width would be < 1."""
        op_name = kwargs["opName"]

        shape_list = kwargs["shapeList"]
        in_shape = shape_list[0]

        args = kwargs["args"]

        if isinstance(args, dict):
            params = args
        else:
            # Create params from list elements
            # TODO - Remove this once all NWHC operators agFunctions have been
            # converted to args_dict output

            # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
            s_idx, p_idx = (0, 1) if op_name == "max_pool2d" else (1, 2)
            params = {"stride": args[s_idx], "pad": args[p_idx]}
            # Alias different info for each op
            params["kernel"] = args[p_idx + 1]
            params["out_shape"] = args[p_idx + 1]
            params["dilation"] = args[p_idx + 1]

        # Common info for all ops
        stride = params["stride"]
        pad = params["pad"]

        if op_name.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel = params["kernel"]
            # axis 0 = height (pads 0,1), axis 1 = width (pads 2,3)
            dims = [
                (
                    in_shape[axis + 1]
                    + pad[2 * axis]
                    + pad[2 * axis + 1]
                    + stride[axis]
                    - kernel[axis]
                )
                // stride[axis]
                for axis in (0, 1)
            ]
            # Invalid if either spatial dimension collapses below 1
            return min(dims) < 1

        if op_name.startswith("transpose_conv2d"):
            # Kernel spatial size comes from the filter tensor shape
            kernel = shape_list[1][1:-1]
            # transpose_conv2d output size: (in - 1) * stride + kernel + pads
            dims = [
                (in_shape[axis + 1] - 1) * stride[axis]
                + kernel[axis]
                + pad[2 * axis]
                + pad[2 * axis + 1]
                for axis in (0, 1)
            ]
            return min(dims) < 1

        if "conv2d" in op_name or "conv3d" in op_name:
            # conv2d, conv3d, depthwise_conv2d
            dilation = params["dilation"]
            filter_shape = shape_list[1]
            kernel = (
                filter_shape[0:2]
                if op_name.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for axis, k_size in enumerate(kernel):
                out_dim = (
                    in_shape[axis + 1]
                    - 1
                    + pad[2 * axis]
                    + pad[2 * axis + 1]
                    - (k_size - 1) * dilation[axis]
                ) // stride[axis] + 1
                # Invalid as soon as any spatial dimension is < 1
                if out_dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {op_name}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        """Return True when the requested output height or width is <= 0."""
        out_shape = kwargs["args"]["out_shape"]
        return out_shape[1] <= 0 or out_shape[2] <= 0