# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType
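
    # Illustrative usage sketch (not part of the generator). Judging from the
    # index pairs used above and in checkResizeParams, `scale` is laid out as
    # [scale_y_n, scale_y_d, scale_x_n, scale_x_d], so injecting
    # ErrorIf.ScaleSmallerEqualZero simply forces one of the four entries to
    # a value <= 0 (all concrete values here are assumptions):
    #
    #   scale, offset, border = [4, 2, 4, 2], [0, 0], [0, 0]
    #   scale, offset, border, outputDType = TosaErrorIfArgGen.eiResizeErrorIf(
    #       testGen, ErrorIf.ScaleSmallerEqualZero, ResizeMode.NEAREST,
    #       DType.INT8, shapeList, DType.INT8, scale, offset, border
    #   )
    #   # scale is now e.g. [4, 2, -1, 2], so the reference model must
    #   # raise the matching ERROR_IF for this test.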

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None
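
    # Illustrative note (a reading of the code above, with assumed values):
    # the StrideSmallerOne branch only fires when every pad value is already
    # smaller than the matching kernel dimension, so the invalid stride is
    # the only ERROR_IF condition present in the generated test, e.g.
    #
    #   stride, pad, kernel = TosaErrorIfArgGen.eiPoolingErrorIf(
    #       testGen, ErrorIf.StrideSmallerOne, (2, 2), (0, 0, 0, 0), (3, 3)
    #   )
    #   # stride is now e.g. (0, -2); pad and kernel are returned unchanged.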

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False
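
    # Summary of the valid RESCALE pairings encoded above (derived directly
    # from the branches; anything outside this table makes the function
    # return True, i.e. "wrong output type"):
    #
    #   input   -> allowed outputs
    #   INT8    -> UINT8, INT8, INT16, INT32
    #   INT16   -> UINT8, INT8, UINT16, INT16, INT32
    #   INT32   -> INT8, INT16, INT32
    #   INT48   -> INT8, INT16, INT32
    #   UINT8   -> INT8, INT16
    #   UINT16  -> INT16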

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape
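
    # Worked examples (values assumed): the shape is first clamped
    # per-dimension, then every dimension is decremented together until the
    # element count fits:
    #
    #   eiRestrictDimensions([1, 64, 64, 40])
    #   # -> [1, 32, 32, 32]  (32768 items <= 100000, so the loop never runs)
    #
    #   eiRestrictDimensions([50, 50, 50], max_dim=49, max_items=1000)
    #   # -> clamps to [49, 49, 49], then shrinks to [10, 10, 10] (1000 items)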

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result
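
    # The expected_result expression above encodes this truth table
    # (restated as a comment only; the logic is unchanged):
    #
    #   names match | validator fired | expected_result
    #   ------------+-----------------+---------------------------------
    #   yes         | yes             | True  (wanted error was caught)
    #   yes         | no              | False (missed ERROR_IF)
    #   no          | yes             | False (unexpected ERROR_IF)
    #   no          | no              | True  (unrelated validator quiet)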

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            op = kwargs["op"]
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if op["op"] in [Op.SCATTER, Op.GATHER]:
                # SCATTER/GATHER add an indices input tensor in their build functions
                num_operands += 1
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = (
                    output.shape
                )  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )
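
    # Worked example (values assumed): scale = [4, 2, 4, 2], offset = [0, 0],
    # border = [0, 0] satisfies every clause above: numerators 4 <= 2048,
    # denominators 2 < 16 * 4 = 64, offsets within [-4, 64), borders within
    # [-64, 4). checkResizeParams returns True, so the resize validators
    # below will trust these parameters when predicting output shapes.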

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
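
    # Numeric example (values assumed) for the formula above, with an
    # input of H x W = 8 x 8, scale = [4, 2, 4, 2], offset = [0, 0],
    # border = [0, 0]:
    #
    #   output_y = ((8 - 1) * 4 - 0 + 0) // 2 + 1 = 15
    #   output_x = ((8 - 1) * 4 - 0 + 0) // 2 + 1 = 15
    #
    # so any NHWC output shape other than [N, 15, 15, C] is flagged as
    # ResizeOutputShapeMismatch.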

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                for i in range(
                    min(len(input1_shape), len(input2_shape), len(input3_shape))
                ):
                    if (
                        (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
                        or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
                        or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )

    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
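
    # Numeric example (values assumed): a 1x8x8xC input with a 2x2 kernel,
    # stride 2x2 and no padding gives
    #
    #   y_correct = ((8 + 0 + 0 - 2) // 2) + 1 = 4
    #   x_correct = ((8 + 0 + 0 - 2) // 2) + 1 = 4
    #
    # so only a 1x4x4xC output avoids PoolingOutputShapeMismatch.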
1494
1495 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001496 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1497 error_name = ErrorIf.PoolingOutputShapeNonInteger
1498 param_reqs = {"rank": None, "dtype": None, "shape": None}
1499 error_result = False
1500 error_reason = "Parameters do not yield exact integer output dimensions"
1501
1502 if check:
1503 pad = kwargs["pad"]
1504 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1505
1506 kernel = kwargs["kernel"]
1507 kernel_y, kernel_x = kernel[0], kernel[1]
1508
1509 input_shape = kwargs["input_shape"]
1510 IH, IW = input_shape[1], input_shape[2]
1511
1512 stride = kwargs["stride"]
1513 stride_y, stride_x = stride[0], stride[1]
1514
1515 # calculate remainder of height, width dimensions
1516 if stride_x != 0 and stride_y != 0:
1517 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1518 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkConvParams(op, weight_shape, stride, pad, dilation):
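        """Return True if the convolution parameters are valid.

        Kernel dimensions (weight_shape[1:-1]) and strides must each be at
        least one, as must dilations when supplied (TRANSPOSE_CONV2D passes
        None). Pad values must be non-negative, except for TRANSPOSE_CONV2D
        where each out_pad value only needs to stay strictly greater than the
        negated kernel size on its axis.
        """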
        if op == Op.TRANSPOSE_CONV2D:
            pad_ok = (
                pad[0] > -weight_shape[1]
                and pad[1] > -weight_shape[1]
                and pad[2] > -weight_shape[2]
                and pad[3] > -weight_shape[2]
            )
        else:
            pad_ok = min(pad) >= 0

        return (
            # Check kernel sizes
            min(weight_shape[1:-1]) >= 1
            and min(stride) >= 1
            and pad_ok
            and (dilation is None or min(dilation) >= 1)
        )

    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )
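            # e.g. conv2d over input H = 16 with pad = (1, 1), kernel H = 3,
            # dilation = 1 and stride = 1 expects an output height of
            # (16 - 1 + 1 + 1 - (3 - 1) * 1) // 1 + 1 = 16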

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate remainder of height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )
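            # e.g. input H = 10 with no padding, kernel H = 3, dilation = 1
            # and stride = 2 gives (10 - 1 + 0 + 0 - 2) % 2 = 1, so the
            # output height is not an exact integer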

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            dimension_match = True
            axis_shift = 0

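            # The expected output shape is the input shape with the axis
            # dimension removed, e.g. input (1, 4, 4, 8) with axis=3 should
            # give output (1, 4, 4)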
            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputRankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output rank provided and expected output rank"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]
            valid_params = axis >= 0 and axis < len(input_shape)

            if valid_params and (len(input_shape) - 1) != len(output_shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.KernelSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one kernel dimension is smaller than one"

        if check:
            kernel = kwargs["kernel"]
            if min(kernel) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStrideSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one stride dimension is smaller than one"

        if check:
            stride = kwargs["stride"]
            if min(stride) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDilationSmallerOne(check=False, **kwargs):
        error_result = check and min(kwargs["dilation"]) < 1
        return {
            "error_name": ErrorIf.DilationSmallerOne,
            "error_reason": "At least one dilation is smaller than one",
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
            "error_result": error_result,
        }

    @staticmethod
    def evScaleTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleTrue
        param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
        error_result = False
        error_reason = "Scale set to true but input type is INT48"

        if check:
            input_dtype = kwargs["input_dtype"]
            scale32 = kwargs["scale32"]
            if scale32 and input_dtype == DType.INT48:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNotTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleNotTrue
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale set to false but double round set to true"

        if check:
            scale32 = kwargs["scale32"]
            double_round = kwargs["double_round"]
            if not scale32 and double_round:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evTensorSizeInputOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.TensorSizeInputOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input tensor size does not match output tensor size"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            input_size = np.prod(input_shape)
            output_size = np.prod(output_shape)
            if input_size != output_size:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.StartSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point smaller than zero"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            rank = len(input_shape)
            if len(start) == rank:
                for index in range(rank):
                    if start[index] < 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.SizeSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size smaller than or equal to zero"

        if check:
            input_shape = kwargs["input_shape"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(size) == rank:
                for index in range(rank):
                    if size[index] <= 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSizeOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.StartSizeOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point plus size larger than input dimension"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(start) == rank and len(size) == rank:
                for index in range(rank):
                    if start[index] + size[index] > input_shape[index]:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.SizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size does not match output dimension"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            size = kwargs["size"]

            if len(input_shape) == len(output_shape):
                rank = len(input_shape)
                if len(size) == rank:
                    for index in range(rank):
                        if size[index] != output_shape[index]:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputSizeStartLengthMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputSizeStartLengthMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank of input not equal to length of start or size"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if rank != len(start) or rank != len(size):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.IndexOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index outside of allowed bounds"

        if check:
            input_shape = kwargs["input_shape"]
            perms = kwargs["perms"]
            rank = len(input_shape)

            for index in perms:
                # valid permutation indices are 0 .. rank - 1
                if index < 0 or index >= rank:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexUsedTwice(check=False, **kwargs):
        error_name = ErrorIf.IndexUsedTwice
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index used multiple times"

        if check:
            perms = kwargs["perms"]

            unique_indices = []
            for index in perms:
                if index in unique_indices:
                    error_result = True
                else:
                    unique_indices.append(index)

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.MaxSmallerMin
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Max value smaller than min value"

        if check:
            max_val = kwargs["max_val"]
            min_val = kwargs["min_val"]
            if max_val < min_val:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputRankMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input ranks are not identical"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputDimMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputDimMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions differ on too many axes"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_rank = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_rank = False

            if valid_rank:
                for input in inputs:
                    for i, dim in enumerate(input.shape):
                        if dim != input_shape[i] and axis != i:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatShapeSumMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatShapeSumMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Sum of dimensions on axis not equal to output dimension"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            axis = kwargs["axis"]

            # Ensure rank and axis are valid before checking dims.
            valid_params = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_params = False
            # axis must be a usable index into the shapes
            if axis < 0 or axis >= len(input_shape):
                valid_params = False

            if valid_params:
                axis_dim_sum = 0
                for input in inputs:
                    axis_dim_sum += input.shape[axis]

                if axis_dim_sum != output_shape[axis]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match then-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
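            # Block ordering assumed by the COND_IF checks below:
            # basicBlocks[0] = cond_if op block, [1] = then graph,
            # [2] = else graph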
            then_block = basicBlocks[1]
            then_inputs = then_block.inputs
            then_tens = then_block.tensors
            if (a.shape != then_tens[then_inputs[0]].shape) or (
                b.shape != then_tens[then_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

2152
2153 @staticmethod
2154 def evInputListElseGraphMismatch(check=False, **kwargs):
2155 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2156 param_reqs = {"rank": None, "dtype": None, "shape": None}
2157 error_result = False
2158 error_reason = "Input list shape does not match else-graph shape"
2159
2160 if check:
2161 a = kwargs["a"]
2162 b = kwargs["b"]
2163 basicBlocks = kwargs["basicBlocks"]
2164 else_block = basicBlocks[2]
2165 else_inputs = else_block.inputs
2166 else_tens = else_block.tensors
2167 if (a.shape != else_tens[else_inputs[0]].shape) or (
2168 b.shape != else_tens[else_inputs[1]].shape
2169 ):
2170 error_result = True
2171
2172 info_dict = {
2173 "error_name": error_name,
2174 "error_result": error_result,
2175 "error_reason": error_reason,
2176 "param_reqs": param_reqs,
2177 }
2178 return info_dict
2179
2180 @staticmethod
2181 def evOutputListThenGraphMismatch(check=False, **kwargs):
2182 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2183 param_reqs = {"rank": None, "dtype": None, "shape": None}
2184 error_result = False
2185 error_reason = "Output list shape does not match then-graph shape"
2186
2187 if check:
2188 basicBlocks = kwargs["basicBlocks"]
2189 cond_block = basicBlocks[0]
2190 cond_outputs = cond_block.outputs
2191 cond_tens = cond_block.tensors
2192 then_block = basicBlocks[1]
2193 then_outputs = then_block.outputs
2194 then_tens = then_block.tensors
2195 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2196 error_result = True
2197
2198 info_dict = {
2199 "error_name": error_name,
2200 "error_result": error_result,
2201 "error_reason": error_reason,
2202 "param_reqs": param_reqs,
2203 }
2204 return info_dict
2205
    @staticmethod
    def evOutputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match else-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            else_block = basicBlocks[2]
            else_outputs = else_block.outputs
            else_tens = else_block.tensors
            if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not match bool type"

        if check:
            cond = kwargs["cond"]
            if cond.dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not have a size of one"

        if check:
            cond = kwargs["cond"]
            # Size of 1 is equivalent to rank 0
            if len(cond.shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListOutputListMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListOutputListMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match output list"

        if check:
            basicBlocks = kwargs["basicBlocks"]
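            # Block ordering assumed by the WHILE_LOOP checks below:
            # basicBlocks[0] = while op block, [1] = cond graph,
            # [2] = body graph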
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_outputs = while_block.outputs
            while_tens = while_block.tensors
            if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListCondGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListCondGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match cond graph"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            cond_block = basicBlocks[1]
            cond_inputs = cond_block.inputs
            cond_tens = cond_block.tensors
            if (
                while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
            ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphInputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphInputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph input"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph output"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_outputs = body_block.outputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output does not match bool type"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output does not have a size of one"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            # Size of 1 is equivalent to rank 0
            if len(cond_tens[cond_outputs[0]].shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelNotPowerOfTwo(check=False, **kwargs):
        error_name = ErrorIf.KernelNotPowerOfTwo
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Kernel height and/or width not a power of two"

        def is_power_of_two(x):
            # math.log2 is exact for powers of two, unlike math.log(x, 2)
            # which can suffer rounding error; also guard against x <= 0
            return x > 0 and math.log2(x).is_integer()

        if check:
            shape = kwargs["input_shape"]
            if len(shape) == 3:
                valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
                error_result = not valid_kernel

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evFFTInputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTInputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between real and imaginary input shapes"

        if check:
            input1 = kwargs["input1"]
            input2 = kwargs["input2"]

            if input1.shape != input2.shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evFFTOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between provided and expected output kernel (H, W) shape"
        )

        if check:
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]

            if len(input_shape) == 3:
                output_shapes = kwargs["output_shape"]

                # Ignoring batch size (N) from input shape
                expected_shape = input_shape[1:]
                if op["op"] == Op.RFFT2D:
                    expected_shape[1] = expected_shape[1] // 2 + 1
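                    # e.g. an RFFT2D input of shape (8, 16, 16) expects both
                    # outputs to have (H, W) = (16, 16 // 2 + 1) = (16, 9)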

                # Ignoring batch size (N) from output shapes
                output_shape_0 = output_shapes[0][1:]
                output_shape_1 = output_shapes[1][1:]
                # Ensure the kernel sizes (H, W) of both outputs match the expected shape
                if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict


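# Illustrative usage sketch (not part of this module's API): a negative-test
# harness can call any ev* checker with check=True and inspect the returned
# info_dict to decide whether the ERROR_IF condition fired, e.g.
#
#     info = TosaErrorValidator.evMaxSmallerMin(check=True, max_val=1, min_val=5)
#     assert info["error_result"] is True
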
class TosaInvalidValidator:
    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[5]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (input_dtype != output_dtype) or (
                input_dtype
                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        # Skip accum_dtype arg (apart from max_pool2d, which doesn't have one)
        stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)

        # Common info for all ops
        strides = args[stride_idx]
        padding = args[pad_idx]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[pad_idx + 1]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
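            # e.g. input H = 1 with no padding, stride 2 and kernel 2 gives
            # h = (1 + 0 + 0 + 2 - 2) // 2 = 0, an invalid dimension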
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[pad_idx + 1]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension."""
                return (in_size - 1) * stride + kernel_size + in_pad + out_pad

            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                padding[0],
                padding[1],
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                padding[2],
                padding[3],
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False
            # output shape does not match the expected shape
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[pad_idx + 1]
            filter_shape = inputShapes[1]
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                pad_offset = i * 2
                dim = (
                    input_shape[i + 1]
                    - 1
                    + padding[pad_offset]
                    + padding[pad_offset + 1]
                    - (kernel_shape[i] - 1) * dilations[i]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs["args"]
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
            return True
        return False