# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import logging
import math

import generator.tosa_utils as gtu
import numpy as np
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"
    WrongAccumulatorType = "WrongAccumulatorType"
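    # Illustrative sketch (not part of the original file): each constant above
    # is passed around as the `error_name` selecting which ERROR_IF condition a
    # test should deliberately trigger, and the validators below can be probed
    # directly, e.g.:
    #
    #   info = TosaErrorValidator.evAxisSmallerZero(check=True, axis=-1)
    #   assert info["error_result"] is True  # axis < 0 is flagged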


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        rng,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = rng.randInt(low=0, high=4)
            scale[index] = rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = rng.choice([0, 2])
            scale[index] = (1 << 11) + rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = rng.choice([0, 1])
            offset[index] = -scale[index * 2] - rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = rng.choice([0, 1])
            border[index] = scale[index * 2] + rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType
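    # Illustrative sketch (assumes `rng` provides the randInt/choice interface
    # used above): forcing ScaleNLargerMax bumps one scale numerator above the
    # 1 << 11 (2048) limit later checked by TosaErrorValidator.evScaleNLargerMax:
    #
    #   scale, offset, border, out_dtype = TosaErrorIfArgGen.eiResizeErrorIf(
    #       rng, ErrorIf.ScaleNLargerMax, ResizeMode.NEAREST, DType.INT8,
    #       [[1, 8, 8, 1]], DType.INT8, [4, 2, 4, 2], [0, 0], [0, 0],
    #   )
    #   # now scale[0] or scale[2] is > 2048, so the validator flags it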

    @staticmethod
    def eiPoolingErrorIf(rng, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                rng.choice([0, -1, -2, -3]),
                rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                rng.choice([-1, -2, -3]),
                rng.choice([-1, -2, -3]),
                rng.choice([-1, -2, -3]),
                rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                rng.choice([0, -1, -2, -3]),
                rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False
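    # Illustrative sketch: the legal RESCALE output dtypes depend only on the
    # input dtype, so for example:
    #
    #   TosaErrorIfArgGen.eiRescaleWrongOutputType(DType.INT8, DType.INT32)
    #   # -> False (INT8 -> INT32 is a valid rescale)
    #   TosaErrorIfArgGen.eiRescaleWrongOutputType(DType.UINT16, DType.INT8)
    #   # -> True (UINT16 may only rescale to INT16)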

    @staticmethod
    def eiInvalidateInputOutputList(rng, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while gtu.product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape
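    # Worked example: eiRestrictDimensions([5, 40, 60]) first clamps each
    # dimension to max_dim, giving [5, 32, 32]; its product (5120) is already
    # under max_items (100000), so no further shrinking is needed.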

    @staticmethod
    def eiSliceErrorIf(rng, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = rng.choice([True, False])

            # Removing a dimension from a 1-D tensor would give an empty tensor.
            if len(start) == 1 or len(size) == 1:
                remove = False

            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(input_dtype):
        if input_dtype in [DType.BOOL]:
            outputDType = [
                DType.BOOL,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
        elif input_dtype in [DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            outputDType = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            ]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
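    # Illustrative sketch: the returned list holds dtypes that are invalid CAST
    # destinations for the given input, e.g.
    #
    #   TosaErrorIfArgGen.eiCastErrorIf(DType.INT16)
    #   # -> [DType.INT48] (the only generated bad cast target for INT16)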


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        if validator_fcns is None:
            # Nothing to do
            return True
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                logger.error(
                    f"Unexpected ERROR_IF: Op: {gtu.valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                logger.error(
                    f"Missed ERROR_IF: Op: {gtu.valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = gtu.valueToName(DType, v)
                        logger.error(f"  {k} = {v}")

        return overall_result
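    # Illustrative sketch (hypothetical validator name, following the contract
    # evValidateErrorIfs relies on above): every evXxx validator takes a
    # `check` flag plus kwargs; with check=False only the metadata is filled
    # in, and the returned info dict always has this shape:
    #
    #   {
    #       "error_name": ErrorIf.SomeCondition,  # which ERROR_IF it covers
    #       "error_result": bool,                 # True if the ERROR_IF fired
    #       "error_reason": "human readable text",
    #       "param_reqs": {"rank": ..., "dtype": ..., "shape": ...},
    #   }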

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(gtu.usableDTypes(excludes=allowed_input_dtypes))
        assert len(wrong_input_dtypes) > 0

        # Turn the wrong dtypes into required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
                    or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [
                        DType.INT8,
                        DType.INT16,
                        DType.FP16,
                        DType.BF16,
                        DType.FP32,
                        DType.FP8E4M3,
                        DType.FP8E5M2,
                    ]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                            DType.FP8E4M3,
                            DType.FP8E5M2,
                        ]
                    )
                    or (
                        input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
                        and output_dtype
                        not in [
                            DType.FP16,
                            DType.BF16,
                            DType.FP32,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype != DType.FP16
                    or input_dtype == DType.BF16
                    and output_dtype != DType.BF16
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP8E4M3
                    and output_dtype != DType.FP16
                    or input_dtype == DType.FP8E5M2
                    and output_dtype != DType.FP16
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
                if output_dtype != DType.SHAPE:
                    error_result = True

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        # From 1 to MAX_TENSOR_RANK+1 inclusively
        all_ranks = tuple(range(1, gtu.MAX_TENSOR_RANK + 2))

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {gtu.MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= gtu.MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= gtu.MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= gtu.MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= gtu.MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = (
                    output.shape
                )  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )
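    # Worked example: scale = [4, 2, 4, 2], offset = [0, 0], border = [0, 0]
    # satisfies every condition above (numerators 4 <= 2048, denominators
    # 2 < 16 * 4, offsets within [-4, 64), borders within [-64, 4)), so
    # checkResizeParams(scale, offset, border) returns True.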

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < gtu.MAX_RESIZE_DIMENSION
                and max(input_shape) < gtu.MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
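    # Worked example of the expected-shape formula above: with an 8x8 input,
    # scale = [4, 2, 4, 2] (a 2x upscale), offset = [0, 0], border = [0, 0]:
    #   output_y = ((8 - 1) * 4 - 0 + 0) // 2 + 1 = 15
    # so an [N, 15, 15, C] output matches and anything else is flagged.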

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            op = kwargs["op"]
            if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
                output_shape = kwargs["result_tensors"][0].shape
                if input1_shape != output_shape:
                    error_result = True

            elif len(input1_shape) == len(input2_shape) == len(input3_shape):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
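    # Illustrative note (assumes the helper follows TOSA broadcasting rules):
    # calculateBroadcastShape is defined elsewhere in this class; paired
    # dimensions must match or be 1, so e.g. [1, 3, 1] and [5, 3, 4] broadcast
    # to [5, 3, 4], while [2, 3] and [3, 2] are incompatible and yield None,
    # which evDimensionMismatch above treats as "nothing to check".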

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1426
1427 @staticmethod
1428 def evPadSmallerZero(check=False, **kwargs):
1429 error_name = ErrorIf.PadSmallerZero
1430 param_reqs = {"rank": None, "dtype": None, "shape": None}
1431 error_result = False
1432 error_reason = "At least one pad is smaller than zero"
1433
1434 if check:
1435 op = kwargs["op"]
1436 pad = kwargs["pad"]
1437 if op["op"] == Op.PAD:
1438 for padding in pad:
1439 if min(padding) < 0:
1440 error_result = True
1441 else:
1442 if min(pad) < 0:
1443 error_result = True
1444
1445 info_dict = {
1446 "error_name": error_name,
1447 "error_result": error_result,
1448 "error_reason": error_reason,
1449 "param_reqs": param_reqs,
1450 }
1451 return info_dict
1452
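    # Note (added for clarity): transpose_conv2d permits negative padding, so
    # in the validator below the failing condition is expressed with the sign
    # flipped (pad <= -kernel_dim) rather than pad >= kernel_dim.
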
    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than or equal to the kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )

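    # Note (added for clarity): the expected pooling output size per spatial
    # dimension is
    #     out = ((in + pad_before + pad_after - kernel) // stride) + 1
    # e.g. in=8, pad=0+0, kernel=2, stride=2 gives out = (8 - 2) // 2 + 1 = 4.
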
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkConvParams(op, weight_shape, stride, pad, dilation):
        if op == Op.TRANSPOSE_CONV2D:
            pad_ok = (
                pad[0] > -weight_shape[1]
                and pad[1] > -weight_shape[1]
                and pad[2] > -weight_shape[2]
                and pad[3] > -weight_shape[2]
            )
        else:
            pad_ok = min(pad) >= 0

        return (
            # Check kernel sizes
            min(weight_shape[1:-1]) >= 1
            and min(stride) >= 1
            and pad_ok
            and (dilation is None or min(dilation) >= 1)
        )

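    # Note (added for clarity): the expected convolution output size per
    # spatial dimension is
    #     out = (in - 1 + pad_before + pad_after
    #            - (kernel - 1) * dilation) // stride + 1
    # while transpose_conv2d inverts this:
    #     out = (in - 1) * stride + pad_before + pad_after + kernel
    # e.g. in=8, pad=1+1, kernel=3, dilation=1, stride=1 gives
    # out = (8 - 1 + 2 - 2) // 1 + 1 = 8 (a "same" convolution).
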
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate remainder of height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            dimension_match = True
            axis_shift = 0

            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputRankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]
            valid_params = 0 <= axis < len(input_shape)

            if valid_params and (len(input_shape) - 1) != len(output_shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.KernelSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one kernel dimension is smaller than one"

        if check:
            kernel = kwargs["kernel"]
            if min(kernel) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStrideSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one stride dimension is smaller than one"

        if check:
            stride = kwargs["stride"]
            if min(stride) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDilationSmallerOne(check=False, **kwargs):
        error_result = check and min(kwargs["dilation"]) < 1
        return {
            "error_name": ErrorIf.DilationSmallerOne,
            "error_reason": "At least one dilation is smaller than one",
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
            "error_result": error_result,
        }

    @staticmethod
    def evScaleTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleTrue
        param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
        error_result = False
        error_reason = "Scale set to true but input type is INT48"

        if check:
            input_dtype = kwargs["input_dtype"]
            scale32 = kwargs["scale32"]
            if scale32 and input_dtype == DType.INT48:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNotTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleNotTrue
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale set to false but double round set to true"

        if check:
            scale32 = kwargs["scale32"]
            double_round = kwargs["double_round"]
            if not scale32 and double_round:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evTensorSizeInputOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.TensorSizeInputOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input tensor size does not match output tensor size"

        if check:
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            shape_inferencing = False
            if -1 in output_shape and op["op"] == Op.RESHAPE:
                shape_inferencing = True
            input_size = np.prod(input_shape)
            output_size = np.prod(output_shape)
            if input_size != output_size and not shape_inferencing:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.StartSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point smaller than zero"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            rank = len(input_shape)
            if len(start) == rank:
                for index in range(rank):
                    if start[index] < 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.SizeSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size smaller than or equal to zero"

        if check:
            input_shape = kwargs["input_shape"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(size) == rank:
                for index in range(rank):
                    if size[index] <= 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSizeOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.StartSizeOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point plus size larger than input dimension"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(start) == rank and len(size) == rank:
                for index in range(rank):
                    if start[index] + size[index] > input_shape[index]:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.SizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size does not match output dimension"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            size = kwargs["size"]

            if len(input_shape) == len(output_shape):
                rank = len(input_shape)
                if len(size) == rank:
                    for index in range(rank):
                        if size[index] != output_shape[index]:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputSizeStartLengthMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputSizeStartLengthMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank of input not equal to length of start or size"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if rank != len(start) or rank != len(size):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

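    # Note (added for clarity): the next two validators cover the TRANSPOSE
    # perms attribute; between them they require perms to be a permutation of
    # [0, ..., rank - 1] - every index in bounds and no index repeated.
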
    @staticmethod
    def evIndexOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.IndexOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index outside of allowed bounds"

        if check:
            input_shape = kwargs["input_shape"]
            perms = kwargs["perms"]
            rank = len(input_shape)

            for index in perms:
                # Valid indices are 0 .. rank - 1
                if index < 0 or index >= rank:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexUsedTwice(check=False, **kwargs):
        error_name = ErrorIf.IndexUsedTwice
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index used multiple times"

        if check:
            perms = kwargs["perms"]

            unique_indices = []
            for index in perms:
                if index in unique_indices:
                    error_result = True
                else:
                    unique_indices.append(index)

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.MaxSmallerMin
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Max value smaller than min value"

        if check:
            max_val = kwargs["max_val"]
            min_val = kwargs["min_val"]
            if max_val < min_val:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputRankMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input ranks are not identical"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            for tensor in inputs:
                if len(tensor.shape) != len(input_shape):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputDimMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputDimMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions differ on too many axes"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_rank = True
            for tensor in inputs:
                if len(tensor.shape) != len(input_shape):
                    valid_rank = False

            if valid_rank:
                for tensor in inputs:
                    for i, dim in enumerate(tensor.shape):
                        if dim != input_shape[i] and axis != i:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatShapeSumMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatShapeSumMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Sum of dimensions on axis not equal to output dimension"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            axis = kwargs["axis"]

            # Ensure rank and axis are valid before indexing shapes.
            valid_params = True
            for tensor in inputs:
                if len(tensor.shape) != len(input_shape):
                    valid_params = False
            if axis < 0 or axis >= len(input_shape):
                valid_params = False

            if valid_params:
                axis_dim_sum = 0
                for tensor in inputs:
                    axis_dim_sum += tensor.shape[axis]

                if axis_dim_sum != output_shape[axis]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match then-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            then_block = basicBlocks[1]
            then_inputs = then_block.inputs
            then_tens = then_block.tensors
            if (a.shape != then_tens[then_inputs[0]].shape) or (
                b.shape != then_tens[then_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match else-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            else_block = basicBlocks[2]
            else_inputs = else_block.inputs
            else_tens = else_block.tensors
            if (a.shape != else_tens[else_inputs[0]].shape) or (
                b.shape != else_tens[else_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match then-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            then_block = basicBlocks[1]
            then_outputs = then_block.outputs
            then_tens = then_block.tensors
            if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match else-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            else_block = basicBlocks[2]
            else_outputs = else_block.outputs
            else_tens = else_block.tensors
            if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not match bool type"

        if check:
            cond = kwargs["cond"]
            if cond.dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not have a size of one"

        if check:
            cond = kwargs["cond"]
            # Size of 1 is equivalent to rank 0
            if len(cond.shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListOutputListMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListOutputListMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match output list"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_outputs = while_block.outputs
            while_tens = while_block.tensors
            if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListCondGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListCondGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match cond graph"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            cond_block = basicBlocks[1]
            cond_inputs = cond_block.inputs
            cond_tens = cond_block.tensors
            if (
                while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
            ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphInputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphInputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph input"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph output"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_outputs = body_block.outputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output is not a boolean tensor"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output does not have a size of one"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            # Size of 1 is equivalent to rank 0
            if len(cond_tens[cond_outputs[0]].shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelNotPowerOfTwo(check=False, **kwargs):
        error_name = ErrorIf.KernelNotPowerOfTwo
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "kernel height and/or width not a power of two"

        def is_power_of_two(x):
            return math.log(x, 2).is_integer()

        if check:
            shape = kwargs["input_shape"]
            if len(shape) == 3:
                valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
                error_result = not valid_kernel

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

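    # Note (added for clarity): the FFT validators operate on rank 3 inputs of
    # shape [N, H, W]; FFT2D additionally requires H and W to be powers of
    # two, which is what evKernelNotPowerOfTwo above screens for.
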
    @staticmethod
    def evFFTInputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTInputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between real and imaginary input shapes"

        if check:
            input1 = kwargs["input1"]
            input2 = kwargs["input2"]

            if input1.shape != input2.shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evFFTOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between provided and expected output kernel (H, W) shape"
        )

        if check:
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]

            if len(input_shape) == 3:
                output_shapes = kwargs["output_shape"]

                # Ignoring batch size (N) from input shape
                expected_shape = input_shape[1:]
                if op["op"] == Op.RFFT2D:
                    expected_shape[1] = expected_shape[1] // 2 + 1

                # Ignoring batch size (N) from output shapes
                output_shape_0 = output_shapes[0][1:]
                output_shape_1 = output_shapes[1][1:]
                # Ensure the kernel sizes (H, W) of both outputs match the expected
                if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def calculateBroadcastShape(input_shape_a, input_shape_b):
        if input_shape_a is not None and input_shape_b is not None:
            calculated_shape = input_shape_a.copy()
            for idx in range(len(calculated_shape)):
                if calculated_shape[idx] == 1:
                    calculated_shape[idx] = input_shape_b[idx]
                elif (
                    input_shape_b[idx] != 1
                    and input_shape_b[idx] != calculated_shape[idx]
                ):
                    return None
            return calculated_shape
        else:
            return None

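    # Note (added for clarity): calculateBroadcastShape implements same-rank
    # broadcasting, where a dimension of 1 stretches to match the other input:
    #     [2, 1, 4] with [1, 3, 4] -> [2, 3, 4]
    #     [2, 3, 4] with [2, 5, 4] -> None (3 vs 5 cannot broadcast)
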
    @staticmethod
    def evBroadcastShapesMismatch(check=False, **kwargs):
        error_name = ErrorIf.BroadcastShapesMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Broadcast shape calculation failed"

        if check:
            input_shape_a = kwargs["input1"].shape
            input_shape_b = kwargs["input2"].shape
            input_shape_c = (
                kwargs["input3"].shape if "input3" in kwargs else input_shape_b
            )

            if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input_shape_c,
                    TosaErrorValidator.calculateBroadcastShape(
                        input_shape_a, input_shape_b
                    ),
                )
                error_result = calculated_shape is None

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongAccumulatorType(check=False, **kwargs):
        error_name = ErrorIf.WrongAccumulatorType
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "An unsupported accumulator data type was requested"

        if check:
            op = kwargs["op"]
            input_dtype = kwargs["input_dtype"]
            accum_dtype = kwargs["accum_dtype"]
            if op["op"] == Op.AVG_POOL2D:
                if (
                    input_dtype
                    in (
                        DType.INT8,
                        DType.INT16,
                    )
                    and accum_dtype != DType.INT32
                ):
                    error_result = True
                elif (
                    input_dtype
                    in (
                        DType.FP32,
                        DType.BF16,
                    )
                    and accum_dtype != DType.FP32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and accum_dtype not in (
                    DType.FP16,
                    DType.FP32,
                ):
                    error_result = True
                elif (
                    input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
                    and accum_dtype != DType.FP16
                ):
                    error_result = True
            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
                    error_result = True
                elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
                    error_result = True
                elif (
                    input_dtype
                    in (
                        DType.FP32,
                        DType.BF16,
                    )
                    and accum_dtype != DType.FP32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and accum_dtype not in (
                    DType.FP16,
                    DType.FP32,
                ):
                    error_result = True
                elif (
                    input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
                    and accum_dtype != DType.FP16
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict


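# Note (added for clarity): unlike TosaErrorValidator, whose ev* checks drive
# negative tests that the reference model is expected to reject with ERROR_IF,
# the iv* predicates below appear to be used to filter out argument
# combinations that would not form a meaningful test at all.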
class TosaInvalidValidator:
    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args_dict = kwargs["args"]
        mode = args_dict["mode"]
        output_dtype = args_dict["output_dtype"]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (input_dtype != output_dtype) or (
                input_dtype
                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        if isinstance(args, dict):
            args_dict = args
        else:
            # Create args_dict from list elements
            # TODO - Remove this once all NHWC operators agFunctions have been
            # converted to args_dict output

            # Skip accum_dtype arg (apart from MaxPool2D, which doesn't have one)
            stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
            args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
            # Alias different info for each op
            args_dict["kernel"] = args[pad_idx + 1]
            args_dict["out_shape"] = args[pad_idx + 1]
            args_dict["dilation"] = args[pad_idx + 1]

        # Common info for all ops
        strides = args_dict["stride"]
        padding = args_dict["pad"]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args_dict["kernel"]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args_dict["out_shape"]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension."""
                return (in_size - 1) * stride + kernel_size + in_pad + out_pad

            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                padding[0],
                padding[1],
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                padding[2],
                padding[3],
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False
            # output shape does not match the expected shape
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args_dict["dilation"]
            filter_shape = inputShapes[1]
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                pad_offset = i * 2
                dim = (
                    input_shape[i + 1]
                    - 1
                    + padding[pad_offset]
                    + padding[pad_offset + 1]
                    - (kernel_shape[i] - 1) * dilations[i]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs["args"]
        output_shape = args["out_shape"]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Negative output shape
            return True
        return False
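

# Minimal usage sketch (illustrative only, not part of the generator flow):
# each ev* validator can be queried for its metadata with check=False, or
# evaluated against concrete parameters with check=True.
if __name__ == "__main__":
    # Metadata only - no kwargs are inspected when check is False
    meta = TosaErrorValidator.evAxisSmallerZero()
    print(meta["error_name"], meta["error_reason"])

    # Concrete check - a negative axis must flag the error
    info = TosaErrorValidator.evAxisSmallerZero(check=True, axis=-1)
    assert info["error_result"]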