# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
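        # The RESIZE parameters are packed as scale=[scale_y_n, scale_y_d,
        # scale_x_n, scale_x_d], offset=[offset_y, offset_x] and
        # border=[border_y, border_x], matching the index arithmetic below.
        # Illustrative (assumed) values: a valid 2x upscale is
        # scale=[4, 2, 4, 2]; ScaleNLargerMax then pushes a numerator past
        # the 1 << 11 limit, e.g. scale[0] = (1 << 11) + 1.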
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        if validator_fcns is None:
            # Nothing to do
            return True
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = (
                    output.shape
                )  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
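        # Worked example (assumed values): scale=[4, 2, 4, 2] (a 2x upscale),
        # offset=[0, 0], border=[0, 0] satisfies every bound below, so this
        # returns True; flipping offset[0] to -5 (< -scale[0] = -4) would make
        # it False.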
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1
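                # Worked example (assumed values): IH = 16 with
                # scale=[4, 2, 4, 2], offset=[0, 0], border=[0, 0] gives
                # output_y = ((16 - 1) * 4 - 0 + 0) // 2 + 1 = 31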
                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            if len(input1_shape) == len(input2_shape) == len(input3_shape):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1.
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
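        # Illustrative (assumed) values: kernel=(2, 2), stride=(2, 2),
        # pad=(0, 0, 0, 0) is valid (True); pad=(2, 0, 0, 0) trips
        # pad[0] >= kernel[0] and returns False.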
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )

    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
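                # Worked example (assumed values): IH = 32, kernel_y = 2,
                # stride_y = 2, no padding:
                # y_correct = ((32 + 0 + 0 - 2) // 2) + 1 = 16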

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

1499 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001500 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1501 error_name = ErrorIf.PoolingOutputShapeNonInteger
1502 param_reqs = {"rank": None, "dtype": None, "shape": None}
1503 error_result = False
1504 error_reason = "Parameters do not yield exact integer output dimensions"
1505
1506 if check:
1507 pad = kwargs["pad"]
1508 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1509
1510 kernel = kwargs["kernel"]
1511 kernel_y, kernel_x = kernel[0], kernel[1]
1512
1513 input_shape = kwargs["input_shape"]
1514 IH, IW = input_shape[1], input_shape[2]
1515
1516 stride = kwargs["stride"]
1517 stride_y, stride_x = stride[0], stride[1]
1518
1519 # calculate remainder of height, width dimensions
1520 if stride_x != 0 and stride_y != 0:
1521 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1522 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1523
1524 # ensure parameters are valid
1525 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1526 if params_valid and (y_remainder != 0 or x_remainder != 0):
1527 error_result = True
1528
1529 info_dict = {
1530 "error_name": error_name,
1531 "error_result": error_result,
1532 "error_reason": error_reason,
1533 "param_reqs": param_reqs,
1534 }
1535 return info_dict
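# Sketch of the remainder test (hypothetical numbers): with IH=7, pad_top=0,
# pad_bottom=0, kernel_y=2, stride_y=2, y_remainder = (7 - 2) % 2 = 1, so the
# parameters cannot yield an exact integer output height and the error fires.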
1536
1537 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001538 def checkConvParams(op, weight_shape, stride, pad, dilation):
1539 if op == Op.TRANSPOSE_CONV2D:
1540 pad_ok = (
1541 pad[0] > -weight_shape[1]
1542 and pad[1] > -weight_shape[1]
1543 and pad[2] > -weight_shape[2]
1544 and pad[3] > -weight_shape[2]
1545 )
1546 else:
1547 pad_ok = min(pad) >= 0
1548
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001549 return (
1550 # Check kernel sizes
1551 min(weight_shape[1:-1]) >= 1
1552 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001553 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001554 and (dilation is None or min(dilation) >= 1)
1555 )
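# Assumed example of the TRANSPOSE_CONV2D pad rule above: with a weight_shape
# of [OC, 3, 3, IC], out_pad entries must be greater than -3 (i.e. >= -2),
# so pad=[-2, 0, -1, 0] is accepted while pad=[-3, 0, 0, 0] is rejected;
# all other convolutions simply require min(pad) >= 0.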
1556
1557 @staticmethod
1558 def evConvOutputShapeMismatch(check=False, **kwargs):
1559 error_name = ErrorIf.ConvOutputShapeMismatch
1560 param_reqs = {"rank": None, "dtype": None, "shape": None}
1561 error_result = False
1562 error_reason = (
1563 "Mismatch between output shape provided and expected output shape"
1564 )
1565
1566 if check:
1567 op = kwargs["op"]
1568 pad = kwargs["pad"]
1569 weight_shape = kwargs["weight_shape"]
1570 input_shape = kwargs["input_shape"]
1571 output_shape = kwargs["output_shape"]
1572 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1573 stride = kwargs["stride"]
1574
1575 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1576
1577 # calculate correct dimensions
1578 dims_correct = []
1579 if min(stride) > 0:
1580 for index in range(len(stride)):
1581 pad_offset = index * 2
1582 if op["op"] == Op.TRANSPOSE_CONV2D:
1583 dims_correct.append(
1584 (input_shape[index + 1] - 1) * stride[index]
Eric Kunzec1a97832022-07-01 16:56:09 -07001585 + pad[pad_offset]
1586 + pad[pad_offset + 1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001587 + weight_shape[index + kernel_offset]
1588 )
1589 else:
1590 dims_correct.append(
1591 (
1592 input_shape[index + 1]
1593 - 1
1594 + pad[pad_offset]
1595 + pad[pad_offset + 1]
1596 - (weight_shape[index + kernel_offset] - 1)
1597 * dilation[index]
1598 )
1599 // stride[index]
1600 + 1
1601 )
1602
1603 # ensure parameters are valid
1604 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001605 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001606 )
1607
1608 if params_valid and output_shape[1:-1] != dims_correct:
1609 error_result = True
1610
1611 info_dict = {
1612 "error_name": error_name,
1613 "error_result": error_result,
1614 "error_reason": error_reason,
1615 "param_reqs": param_reqs,
1616 }
1617 return info_dict
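# Worked example of the dims_correct formulas (hypothetical CONV2D values):
# input H=8, pad=(1, 1), kernel H=3, dilation=1, stride=2 gives
#   (8 - 1 + 1 + 1 - (3 - 1) * 1) // 2 + 1 = 4
# and for TRANSPOSE_CONV2D with input H=4, stride=2, out_pad=(0, 0), kernel H=3:
#   (4 - 1) * 2 + 0 + 0 + 3 = 9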
1618
1619 @staticmethod
1620 def evConvOutputShapeNonInteger(check=False, **kwargs):
1621 error_name = ErrorIf.ConvOutputShapeNonInteger
1622 param_reqs = {"rank": None, "dtype": None, "shape": None}
1623 error_result = False
1624 error_reason = "Parameters do not yield exact integer output dimensions"
1625
1626 if check:
1627 op = kwargs["op"]
1628 pad = kwargs["pad"]
1629 weight_shape = kwargs["weight_shape"]
1630 input_shape = kwargs["input_shape"]
1631 dilation = kwargs["dilation"]
1632 stride = kwargs["stride"]
1633
1634 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1635
1636 # calculate remainders of height, width dimensions

1637 remainders = []
1638 if min(stride) > 0:
1639 for index in range(len(stride)):
1640 pad_offset = index * 2
1641 remainders.append(
1642 (
1643 input_shape[index + 1]
1644 - 1
1645 + pad[pad_offset]
1646 + pad[pad_offset + 1]
1647 - (weight_shape[index + kernel_offset] - 1)
1648 * dilation[index]
1649 )
1650 % stride[index]
1651 )
1652
1653 # ensure parameters are valid
1654 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001655 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001656 )
1657 if params_valid and max(remainders) > 0:
1658 error_result = True
1659
1660 info_dict = {
1661 "error_name": error_name,
1662 "error_result": error_result,
1663 "error_reason": error_reason,
1664 "param_reqs": param_reqs,
1665 }
1666 return info_dict
1667
1668 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001669 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1670 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1671 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1672 error_result = False
1673 error_reason = (
1674 "Mismatch between output shape provided and expected output shape"
1675 )
1676
1677 if check:
1678 output_shape = kwargs["output_shape"]
1679 input_shape = kwargs["input_shape"]
1680 axis = kwargs["axis"]
1681
1682 dimension_match = True
1683 axis_shift = 0
1684
1685 # Check that rank is correct before trying to check dimensions
1686 if (len(input_shape) - 1) == len(output_shape):
1687 for i in range(len(input_shape)):
1688 if i == axis:
1689 axis_shift = 1
1690 continue
1691 if input_shape[i] != output_shape[i - axis_shift]:
1692 dimension_match = False
1693
1694 if not dimension_match:
1695 error_result = True
1696
1697 info_dict = {
1698 "error_name": error_name,
1699 "error_result": error_result,
1700 "error_reason": error_reason,
1701 "param_reqs": param_reqs,
1702 }
1703 return info_dict
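# Sketch with assumed shapes: input_shape=[1, 4, 5, 8] and axis=2 should give
# output_shape=[1, 4, 8]; axis_shift skips the reduced axis so input dim i is
# compared with output dim i - 1 once i passes the axis.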
1704
1705 @staticmethod
1706 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1707 error_name = ErrorIf.ArgmaxOutputRankMismatch
1708 param_reqs = {"rank": None, "dtype": None, "shape": None}
1709 error_result = False
1710 error_reason = (
1711 "Mismatch between output shape provided and expected output shape"
1712 )
1713
1714 if check:
1715 output_shape = kwargs["output_shape"]
1716 input_shape = kwargs["input_shape"]
1717 axis = kwargs["axis"]
1718 valid_params = axis >= 0 and axis < len(input_shape)
1719
1720 if valid_params and (len(input_shape) - 1) != len(output_shape):
1721 error_result = True
1722
1723 info_dict = {
1724 "error_name": error_name,
1725 "error_result": error_result,
1726 "error_reason": error_reason,
1727 "param_reqs": param_reqs,
1728 }
1729 return info_dict
1730
1731 @staticmethod
1732 def evKernelSmallerOne(check=False, **kwargs):
1733 error_name = ErrorIf.KernelSmallerOne
1734 param_reqs = {"rank": None, "dtype": None, "shape": None}
1735 error_result = False
1736 error_reason = "At least one kernel dimension is smaller than one"
1737
1738 if check:
1739 kernel = kwargs["kernel"]
1740 if min(kernel) < 1:
1741 error_result = True
1742
1743 info_dict = {
1744 "error_name": error_name,
1745 "error_result": error_result,
1746 "error_reason": error_reason,
1747 "param_reqs": param_reqs,
1748 }
1749 return info_dict
1750
1751 @staticmethod
1752 def evStrideSmallerOne(check=False, **kwargs):
1753 error_name = ErrorIf.StrideSmallerOne
1754 param_reqs = {"rank": None, "dtype": None, "shape": None}
1755 error_result = False
1756 error_reason = "At least one stride dimension is smaller than one"
1757
1758 if check:
1759 stride = kwargs["stride"]
1760 if min(stride) < 1:
1761 error_result = True
1762
1763 info_dict = {
1764 "error_name": error_name,
1765 "error_result": error_result,
1766 "error_reason": error_reason,
1767 "param_reqs": param_reqs,
1768 }
1769 return info_dict
1770
1771 @staticmethod
1772 def evDilationSmallerOne(check=False, **kwargs):
1773 error_result = check and min(kwargs["dilation"]) < 1
1774 return {
1775 "error_name": ErrorIf.DilationSmallerOne,
1776 "error_reason": "At least one dilation is smaller than one",
1777 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1778 "error_result": error_result,
1779 }
1780
1781 @staticmethod
1782 def evScaleTrue(check=False, **kwargs):
1783 error_name = ErrorIf.ScaleTrue
1784 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1785 error_result = False
1786 error_reason = "Scale set to true but input type is INT48"
1787
1788 if check:
1789 input_dtype = kwargs["input_dtype"]
1790 scale32 = kwargs["scale32"]
1791 if scale32 and input_dtype == DType.INT48:
1792 error_result = True
1793
1794 info_dict = {
1795 "error_name": error_name,
1796 "error_result": error_result,
1797 "error_reason": error_reason,
1798 "param_reqs": param_reqs,
1799 }
1800 return info_dict
1801
1802 @staticmethod
1803 def evScaleNotTrue(check=False, **kwargs):
1804 error_name = ErrorIf.ScaleNotTrue
1805 param_reqs = {"rank": None, "dtype": None, "shape": None}
1806 error_result = False
1807 error_reason = "Scale set to false but double round set to true"
1808
1809 if check:
1810 scale32 = kwargs["scale32"]
1811 double_round = kwargs["double_round"]
1812 if not scale32 and double_round:
1813 error_result = True
1814
1815 info_dict = {
1816 "error_name": error_name,
1817 "error_result": error_result,
1818 "error_reason": error_reason,
1819 "param_reqs": param_reqs,
1820 }
1821 return info_dict
1822
1823 @staticmethod
1824 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1825 error_name = ErrorIf.TensorSizeInputOutputMismatch
1826 param_reqs = {"rank": None, "dtype": None, "shape": None}
1827 error_result = False
1828 error_reason = "Input tensor size does not match output tensor size"
Jerry Ge264f7fa2023-04-21 22:49:57 +00001829 op = kwargs.get("op")  # .get avoids a KeyError when called with check=False
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001830
1831 if check:
1832 input_shape = kwargs["input_shape"]
1833 output_shape = kwargs["output_shape"]
Jerry Ge264f7fa2023-04-21 22:49:57 +00001834 shape_inferencing = False
1835 if -1 in output_shape and op["op"] == Op.RESHAPE:
1836 shape_inferencing = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001837 input_size = np.prod(input_shape)
1838 output_size = np.prod(output_shape)
Jerry Ge264f7fa2023-04-21 22:49:57 +00001839 if input_size != output_size and not shape_inferencing:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001840 error_result = True
1841
1842 info_dict = {
1843 "error_name": error_name,
1844 "error_result": error_result,
1845 "error_reason": error_reason,
1846 "param_reqs": param_reqs,
1847 }
1848 return info_dict
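# Sketch of the RESHAPE special case (assumed shapes): input_shape=[2, 3, 4]
# with output_shape=[2, -1] infers the -1 dimension (here 12), so the raw size
# comparison 24 != np.prod([2, -1]) is deliberately not treated as an error.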
1849
1850 @staticmethod
1851 def evStartSmallerZero(check=False, **kwargs):
1852 error_name = ErrorIf.StartSmallerZero
1853 param_reqs = {"rank": None, "dtype": None, "shape": None}
1854 error_result = False
1855 error_reason = "Starting point smaller than zero"
1856
1857 if check:
1858 input_shape = kwargs["input_shape"]
1859 start = kwargs["start"]
1860 rank = len(input_shape)
1861 if len(start) == rank:
1862 for index in range(rank):
1863 if start[index] < 0:
1864 error_result = True
1865
1866 info_dict = {
1867 "error_name": error_name,
1868 "error_result": error_result,
1869 "error_reason": error_reason,
1870 "param_reqs": param_reqs,
1871 }
1872 return info_dict
1873
1874 @staticmethod
1875 def evSizeSmallerEqualZero(check=False, **kwargs):
1876 error_name = ErrorIf.SizeSmallerEqualZero
1877 param_reqs = {"rank": None, "dtype": None, "shape": None}
1878 error_result = False
1879 error_reason = "Size smaller than or equal to zero"
1880
1881 if check:
1882 input_shape = kwargs["input_shape"]
1883 size = kwargs["size"]
1884 rank = len(input_shape)
1885 if len(size) == rank:
1886 for index in range(rank):
1887 if size[index] <= 0:
1888 error_result = True
1889
1890 info_dict = {
1891 "error_name": error_name,
1892 "error_result": error_result,
1893 "error_reason": error_reason,
1894 "param_reqs": param_reqs,
1895 }
1896 return info_dict
1897
1898 @staticmethod
1899 def evStartSizeOutsideBounds(check=False, **kwargs):
1900 error_name = ErrorIf.StartSizeOutsideBounds
1901 param_reqs = {"rank": None, "dtype": None, "shape": None}
1902 error_result = False
1903 error_reason = "Starting point plus size larger than input dimension"
1904
1905 if check:
1906 input_shape = kwargs["input_shape"]
1907 start = kwargs["start"]
1908 size = kwargs["size"]
1909 rank = len(input_shape)
1910 if len(start) == rank and len(size) == rank:
1911 for index in range(rank):
1912 if start[index] + size[index] > input_shape[index]:
1913 error_result = True
1914
1915 info_dict = {
1916 "error_name": error_name,
1917 "error_result": error_result,
1918 "error_reason": error_reason,
1919 "param_reqs": param_reqs,
1920 }
1921 return info_dict
1922
1923 @staticmethod
1924 def evSizeOutputShapeMismatch(check=False, **kwargs):
1925 error_name = ErrorIf.SizeOutputShapeMismatch
1926 param_reqs = {"rank": None, "dtype": None, "shape": None}
1927 error_result = False
1928 error_reason = "Size does not match output dimension"
1929
1930 if check:
1931 input_shape = kwargs["input_shape"]
1932 output_shape = kwargs["output_shape"]
1933 size = kwargs["size"]
Luke Huttona4e48ca2023-02-22 11:53:48 +00001934
1935 if len(input_shape) == len(output_shape):
1936 rank = len(input_shape)
1937 if len(size) == rank:
1938 for index in range(rank):
1939 if size[index] != output_shape[index]:
1940 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001941
1942 info_dict = {
1943 "error_name": error_name,
1944 "error_result": error_result,
1945 "error_reason": error_reason,
1946 "param_reqs": param_reqs,
1947 }
1948 return info_dict
1949
1950 @staticmethod
1951 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1952 error_name = ErrorIf.InputSizeStartLengthMismatch
1953 param_reqs = {"rank": None, "dtype": None, "shape": None}
1954 error_result = False
1955 error_reason = "Rank of input not equal to length of start or size"
1956
1957 if check:
1958 input_shape = kwargs["input_shape"]
1959 start = kwargs["start"]
1960 size = kwargs["size"]
1961 rank = len(input_shape)
1962 if rank != len(start) or rank != len(size):
1963 error_result = True
1964
1965 info_dict = {
1966 "error_name": error_name,
1967 "error_result": error_result,
1968 "error_reason": error_reason,
1969 "param_reqs": param_reqs,
1970 }
1971 return info_dict
1972
1973 @staticmethod
1974 def evIndexOutsideBounds(check=False, **kwargs):
1975 error_name = ErrorIf.IndexOutsideBounds
1976 param_reqs = {"rank": None, "dtype": None, "shape": None}
1977 error_result = False
1978 error_reason = "Index outside of allowed bounds"
1979
1980 if check:
1981 input_shape = kwargs["input_shape"]
1982 perms = kwargs["perms"]
1983 rank = len(input_shape)
1984
1985 for index in perms:
1986 if index < 0 or index >= rank:  # valid permutation indices are 0..rank-1
1987 error_result = True
1988
1989 info_dict = {
1990 "error_name": error_name,
1991 "error_result": error_result,
1992 "error_reason": error_reason,
1993 "param_reqs": param_reqs,
1994 }
1995 return info_dict
1996
1997 @staticmethod
1998 def evIndexUsedTwice(check=False, **kwargs):
1999 error_name = ErrorIf.IndexUsedTwice
2000 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2001 error_result = False
2002 error_reason = "Index used multiple times"
2003
2004 if check:
2005 perms = kwargs["perms"]
2006
2007 unique_indices = []
2008 for index in perms:
2009 if index in unique_indices:
2010 error_result = True
2011 else:
2012 unique_indices.append(index)
2013
2014 info_dict = {
2015 "error_name": error_name,
2016 "error_result": error_result,
2017 "error_reason": error_reason,
2018 "param_reqs": param_reqs,
2019 }
2020 return info_dict
2021
2022 @staticmethod
2023 def evMaxSmallerMin(check=False, **kwargs):
2024 error_name = ErrorIf.MaxSmallerMin
2025 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2026 error_result = False
2027 error_reason = "Max value smaller than min value"
2028
2029 if check:
2030 max_val = kwargs["max_val"]
2031 min_val = kwargs["min_val"]
2032 if max_val < min_val:
2033 error_result = True
2034
2035 info_dict = {
2036 "error_name": error_name,
2037 "error_result": error_result,
2038 "error_reason": error_reason,
2039 "param_reqs": param_reqs,
2040 }
2041 return info_dict
2042
2043 @staticmethod
2044 def evConcatInputRankMismatch(check=False, **kwargs):
2045 error_name = ErrorIf.ConcatInputRankMismatch
2046 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2047 error_result = False
2048 error_reason = "Input ranks are not identical"
2049
2050 if check:
2051 inputs = kwargs["inputs"]
2052 input_shape = kwargs["input_shape"]
2053 for input in inputs:
2054 if len(input.shape) != len(input_shape):
2055 error_result = True
2056
2057 info_dict = {
2058 "error_name": error_name,
2059 "error_result": error_result,
2060 "error_reason": error_reason,
2061 "param_reqs": param_reqs,
2062 }
2063 return info_dict
2064
2065 @staticmethod
2066 def evConcatInputDimMismatch(check=False, **kwargs):
2067 error_name = ErrorIf.ConcatInputDimMismatch
2068 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2069 error_result = False
2070 error_reason = "Input dimensions differ on an axis other than the concat axis"
2071
2072 if check:
2073 inputs = kwargs["inputs"]
2074 input_shape = kwargs["input_shape"]
2075 axis = kwargs["axis"]
2076
2077 # Ensure rank is valid before checking dims.
2078 valid_rank = True
2079 for input in inputs:
2080 if len(input.shape) != len(input_shape):
2081 valid_rank = False
2082
2083 if valid_rank:
2084 for input in inputs:
2085 for i, dim in enumerate(input.shape):
2086 if dim != input_shape[i] and axis != i:
2087 error_result = True
2088
2089 info_dict = {
2090 "error_name": error_name,
2091 "error_result": error_result,
2092 "error_reason": error_reason,
2093 "param_reqs": param_reqs,
2094 }
2095 return info_dict
2096
2097 @staticmethod
2098 def evConcatShapeSumMismatch(check=False, **kwargs):
2099 error_name = ErrorIf.ConcatShapeSumMismatch
2100 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2101 error_result = False
2102 error_reason = "Sum of dimensions on axis not equal to output dimension"
2103
2104 if check:
2105 inputs = kwargs["inputs"]
2106 input_shape = kwargs["input_shape"]
2107 output_shape = kwargs["output_shape"]
2108 axis = kwargs["axis"]
2109
2110 # Ensure rank is valid before checking dims.
2111 valid_params = True
2112 for input in inputs:
2113 if len(input.shape) != len(input_shape):
2114 valid_params = False
2115 if axis < 0 or axis >= len(input_shape):
2116 valid_params = False
2117
2118 if valid_params:
2119 axis_dim_sum = 0
2120 for input in inputs:
2121 axis_dim_sum += input.shape[axis]
2122
2123 if axis_dim_sum != output_shape[axis]:
2124 error_result = True
2125
2126 info_dict = {
2127 "error_name": error_name,
2128 "error_result": error_result,
2129 "error_reason": error_reason,
2130 "param_reqs": param_reqs,
2131 }
2132 return info_dict
2133
2134 @staticmethod
2135 def evInputListThenGraphMismatch(check=False, **kwargs):
2136 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2137 param_reqs = {"rank": None, "dtype": None, "shape": None}
2138 error_result = False
2139 error_reason = "Input list shape does not match then-graph shape"
2140
2141 if check:
2142 a = kwargs["a"]
2143 b = kwargs["b"]
2144 basicBlocks = kwargs["basicBlocks"]
2145 then_block = basicBlocks[1]
2146 then_inputs = then_block.inputs
2147 then_tens = then_block.tensors
2148 if (a.shape != then_tens[then_inputs[0]].shape) or (
2149 b.shape != then_tens[then_inputs[1]].shape
2150 ):
2151 error_result = True
2152
2153 info_dict = {
2154 "error_name": error_name,
2155 "error_result": error_result,
2156 "error_reason": error_reason,
2157 "param_reqs": param_reqs,
2158 }
2159 return info_dict
2160
2161 @staticmethod
2162 def evInputListElseGraphMismatch(check=False, **kwargs):
2163 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2164 param_reqs = {"rank": None, "dtype": None, "shape": None}
2165 error_result = False
2166 error_reason = "Input list shape does not match else-graph shape"
2167
2168 if check:
2169 a = kwargs["a"]
2170 b = kwargs["b"]
2171 basicBlocks = kwargs["basicBlocks"]
2172 else_block = basicBlocks[2]
2173 else_inputs = else_block.inputs
2174 else_tens = else_block.tensors
2175 if (a.shape != else_tens[else_inputs[0]].shape) or (
2176 b.shape != else_tens[else_inputs[1]].shape
2177 ):
2178 error_result = True
2179
2180 info_dict = {
2181 "error_name": error_name,
2182 "error_result": error_result,
2183 "error_reason": error_reason,
2184 "param_reqs": param_reqs,
2185 }
2186 return info_dict
2187
2188 @staticmethod
2189 def evOutputListThenGraphMismatch(check=False, **kwargs):
2190 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2191 param_reqs = {"rank": None, "dtype": None, "shape": None}
2192 error_result = False
2193 error_reason = "Output list shape does not match then-graph shape"
2194
2195 if check:
2196 basicBlocks = kwargs["basicBlocks"]
2197 cond_block = basicBlocks[0]
2198 cond_outputs = cond_block.outputs
2199 cond_tens = cond_block.tensors
2200 then_block = basicBlocks[1]
2201 then_outputs = then_block.outputs
2202 then_tens = then_block.tensors
2203 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2204 error_result = True
2205
2206 info_dict = {
2207 "error_name": error_name,
2208 "error_result": error_result,
2209 "error_reason": error_reason,
2210 "param_reqs": param_reqs,
2211 }
2212 return info_dict
2213
2214 @staticmethod
2215 def evOutputListElseGraphMismatch(check=False, **kwargs):
2216 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2217 param_reqs = {"rank": None, "dtype": None, "shape": None}
2218 error_result = False
2219 error_reason = "Output list shape does not match else-graph shape"
2220
2221 if check:
2222 basicBlocks = kwargs["basicBlocks"]
2223 cond_block = basicBlocks[0]
2224 cond_outputs = cond_block.outputs
2225 cond_tens = cond_block.tensors
2226 else_block = basicBlocks[2]
2227 else_outputs = else_block.outputs
2228 else_tens = else_block.tensors
2229 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2230 error_result = True
2231
2232 info_dict = {
2233 "error_name": error_name,
2234 "error_result": error_result,
2235 "error_reason": error_reason,
2236 "param_reqs": param_reqs,
2237 }
2238 return info_dict
2239
2240 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002241 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2242 error_name = ErrorIf.CondIfCondNotMatchingBool
2243 param_reqs = {"rank": None, "dtype": None, "shape": None}
2244 error_result = False
2245 error_reason = "Conditional tensor does not match bool type"
2246
2247 if check:
2248 cond = kwargs["cond"]
2249 if cond.dtype != DType.BOOL:
2250 error_result = True
2251
2252 info_dict = {
2253 "error_name": error_name,
2254 "error_result": error_result,
2255 "error_reason": error_reason,
2256 "param_reqs": param_reqs,
2257 }
2258 return info_dict
2259
2260 @staticmethod
2261 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2262 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2263 param_reqs = {"rank": None, "dtype": None, "shape": None}
2264 error_result = False
2265 error_reason = "Conditional tensor does not have a size of one"
2266
2267 if check:
2268 cond = kwargs["cond"]
2269 # Size of 1 is equivalent to rank 0
2270 if len(cond.shape) != 0:
2271 error_result = True
2272
2273 info_dict = {
2274 "error_name": error_name,
2275 "error_result": error_result,
2276 "error_reason": error_reason,
2277 "param_reqs": param_reqs,
2278 }
2279 return info_dict
2280
2281 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002282 def evInputListOutputListMismatch(check=False, **kwargs):
2283 error_name = ErrorIf.InputListOutputListMismatch
2284 param_reqs = {"rank": None, "dtype": None, "shape": None}
2285 error_result = False
2286 error_reason = "Input list does not match output list"
2287
2288 if check:
2289 basicBlocks = kwargs["basicBlocks"]
2290 while_block = basicBlocks[0]
2291 while_inputs = while_block.inputs
2292 while_outputs = while_block.outputs
2293 while_tens = while_block.tensors
2294 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2295 error_result = True
2296
2297 info_dict = {
2298 "error_name": error_name,
2299 "error_result": error_result,
2300 "error_reason": error_reason,
2301 "param_reqs": param_reqs,
2302 }
2303 return info_dict
2304
2305 @staticmethod
2306 def evInputListCondGraphMismatch(check=False, **kwargs):
2307 error_name = ErrorIf.InputListCondGraphMismatch
2308 param_reqs = {"rank": None, "dtype": None, "shape": None}
2309 error_result = False
2310 error_reason = "Input list does not match cond graph"
2311
2312 if check:
2313 basicBlocks = kwargs["basicBlocks"]
2314 while_block = basicBlocks[0]
2315 while_inputs = while_block.inputs
2316 while_tens = while_block.tensors
2317 cond_block = basicBlocks[1]
2318 cond_inputs = cond_block.inputs
2319 cond_tens = cond_block.tensors
2320 if (
2321 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2322 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2323 error_result = True
2324
2325 info_dict = {
2326 "error_name": error_name,
2327 "error_result": error_result,
2328 "error_reason": error_reason,
2329 "param_reqs": param_reqs,
2330 }
2331 return info_dict
2332
2333 @staticmethod
2334 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2335 error_name = ErrorIf.InputListBodyGraphInputMismatch
2336 param_reqs = {"rank": None, "dtype": None, "shape": None}
2337 error_result = False
2338 error_reason = "Input list does not match body graph input"
2339
2340 if check:
2341 basicBlocks = kwargs["basicBlocks"]
2342 while_block = basicBlocks[0]
2343 while_inputs = while_block.inputs
2344 while_tens = while_block.tensors
2345 body_block = basicBlocks[2]
2346 body_inputs = body_block.inputs
2347 body_tens = body_block.tensors
2348 if (
2349 while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
2350 ) or (
2351 while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
2352 ):
2353 error_result = True
2354
2355 info_dict = {
2356 "error_name": error_name,
2357 "error_result": error_result,
2358 "error_reason": error_reason,
2359 "param_reqs": param_reqs,
2360 }
2361 return info_dict
2362
2363 @staticmethod
2364 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2365 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2366 param_reqs = {"rank": None, "dtype": None, "shape": None}
2367 error_result = False
2368 error_reason = "Input list does not match body graph output"
2369
2370 if check:
2371 basicBlocks = kwargs["basicBlocks"]
2372 while_block = basicBlocks[0]
2373 while_inputs = while_block.inputs
2374 while_tens = while_block.tensors
2375 body_block = basicBlocks[2]
2376 body_outputs = body_block.outputs
2377 body_tens = body_block.tensors
2378 if (
2379 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2380 ) or (
2381 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2382 ):
2383 error_result = True

2384 info_dict = {
2385 "error_name": error_name,
2386 "error_result": error_result,
2387 "error_reason": error_reason,
2388 "param_reqs": param_reqs,
2389 }
2390 return info_dict
2391
2392 @staticmethod
2393 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2394 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2395 param_reqs = {"rank": None, "dtype": None, "shape": None}
2396 error_result = False
2397 error_reason = "Cond graph output is not a matching list of booleans"
2398
2399 if check:
2400 basicBlocks = kwargs["basicBlocks"]
2401 cond_block = basicBlocks[1]
2402 cond_outputs = cond_block.outputs
2403 cond_tens = cond_block.tensors
2404 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2405 error_result = True
2406
2407 info_dict = {
2408 "error_name": error_name,
2409 "error_result": error_result,
2410 "error_reason": error_reason,
2411 "param_reqs": param_reqs,
2412 }
2413 return info_dict
2414
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002415 @staticmethod
2416 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2417 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2418 param_reqs = {"rank": None, "dtype": None, "shape": None}
2419 error_result = False
2420 error_reason = "Cond graph output does not have a shape of size one"
2421
2422 if check:
2423 basicBlocks = kwargs["basicBlocks"]
2424 cond_block = basicBlocks[1]
2425 cond_outputs = cond_block.outputs
2426 cond_tens = cond_block.tensors
2427 # Size of 1 is equivalent to rank 0
2428 if len(cond_tens[cond_outputs[0]].shape) != 0:
2429 error_result = True
2430
2431 info_dict = {
2432 "error_name": error_name,
2433 "error_result": error_result,
2434 "error_reason": error_reason,
2435 "param_reqs": param_reqs,
2436 }
2437 return info_dict
2438
Luke Hutton261b7b62023-01-10 14:50:31 +00002439 @staticmethod
2440 def evKernelNotPowerOfTwo(check=False, **kwargs):
2441 error_name = ErrorIf.KernelNotPowerOfTwo
2442 param_reqs = {"rank": None, "dtype": None, "shape": None}
2443 error_result = False
2444 error_reason = "kernel height and/or width not a power of two"
2445
2446 def is_power_of_two(x):
2447 return math.log(x, 2).is_integer()
2448
2449 if check:
2450 shape = kwargs["input_shape"]
2451 if len(shape) == 3:
2452 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2453 error_result = not valid_kernel
2454
2455 info_dict = {
2456 "error_name": error_name,
2457 "error_result": error_result,
2458 "error_reason": error_reason,
2459 "param_reqs": param_reqs,
2460 }
2461 return info_dict
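# Note: math.log(x, 2).is_integer() relies on floating point and could in
# principle misreport very large powers of two; a bit-twiddling alternative
# (a sketch, not the implementation used here) is:
#   def is_power_of_two(x):
#       return x > 0 and (x & (x - 1)) == 0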
2462
Luke Hutton57287132023-02-06 14:54:18 +00002463 @staticmethod
2464 def evFFTInputShapeMismatch(check=False, **kwargs):
2465 error_name = ErrorIf.FFTInputShapeMismatch
2466 param_reqs = {"rank": None, "dtype": None, "shape": None}
2467 error_result = False
2468 error_reason = "Mismatch between real and imaginary input shapes"
2469
2470 if check:
2471 input1 = kwargs["input1"]
2472 input2 = kwargs["input2"]
2473
2474 if input1.shape != input2.shape:
2475 error_result = True
2476
2477 info_dict = {
2478 "error_name": error_name,
2479 "error_result": error_result,
2480 "error_reason": error_reason,
2481 "param_reqs": param_reqs,
2482 }
2483 return info_dict
2484
2485 @staticmethod
2486 def evFFTOutputShapeMismatch(check=False, **kwargs):
2487 error_name = ErrorIf.FFTOutputShapeMismatch
2488 param_reqs = {"rank": None, "dtype": None, "shape": None}
2489 error_result = False
2490 error_reason = (
2491 "Mismatch between provided and expected output kernel (H, W) shape"
2492 )
2493
2494 if check:
2495 op = kwargs["op"]
2496 input_shape = kwargs["input_shape"]
2497
2498 if len(input_shape) == 3:
2499 output_shapes = kwargs["output_shape"]
2500
2501 # Ignoring batch size (N) from input shape
2502 expected_shape = input_shape[1:]
2503 if op["op"] == Op.RFFT2D:
2504 expected_shape[1] = expected_shape[1] // 2 + 1
2505
2506 # Ignoring batch size (N) from output shapes
2507 output_shape_0 = output_shapes[0][1:]
2508 output_shape_1 = output_shapes[1][1:]
2509 # Ensure the kernel sizes (H, W) of both outputs match the expected shape
2510 if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
2511 error_result = True
2512
2513 info_dict = {
2514 "error_name": error_name,
2515 "error_result": error_result,
2516 "error_reason": error_reason,
2517 "param_reqs": param_reqs,
2518 }
2519 return info_dict
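# Assumed example: an RFFT2D input_shape of [N, 8, 8] yields an expected
# per-output kernel shape of [8, 8 // 2 + 1] = [8, 5]; both the real and
# imaginary output shapes must then equal [N, 8, 5] for the test to be valid.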
2520
Jerry Ge264f7fa2023-04-21 22:49:57 +00002521 @staticmethod
Jerry Ge135c9552023-05-23 20:59:32 +00002522 def calculateBroadcastShape(input_shape_a, input_shape_b):
2523 if input_shape_a is not None and input_shape_b is not None:
2524 calculated_shape = input_shape_a.copy()
2525 for idx in range(len(calculated_shape)):
2526 if calculated_shape[idx] == 1:
2527 calculated_shape[idx] = input_shape_b[idx]
2528 elif (
2529 input_shape_b[idx] != 1
2530 and input_shape_b[idx] != calculated_shape[idx]
2531 ):
2532 return None
2533 return calculated_shape
2534 else:
2535 return None
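# Broadcast sketch (hypothetical shapes): [1, 3, 1] and [5, 3, 4] combine to
# [5, 3, 4]; [2, 3, 1] and [5, 3, 4] return None because 2 and 5 are
# incompatible non-unit dimensions. Ranks must already match; no rank
# promotion is performed here.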
2536
2537 @staticmethod
2538 def evBroadcastShapesMismatch(check=False, **kwargs):
2539 error_name = ErrorIf.BroadcastShapesMismatch
2540 param_reqs = {"rank": None, "dtype": None, "shape": None}
2541 error_result = False
2542 error_reason = "Broadcast shape calculation failed"
2543
2544 if check:
2545 input_shape_a = kwargs["input1"].shape
2546 input_shape_b = kwargs["input2"].shape
2547 input_shape_c = (
2548 kwargs["input3"].shape if "input3" in kwargs else input_shape_b
2549 )
2550
2551 if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
2552 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
2553 input_shape_c,
2554 TosaErrorValidator.calculateBroadcastShape(
2555 input_shape_a, input_shape_b
2556 ),
2557 )
2558 error_result = calculated_shape is None
2559
2560 info_dict = {
2561 "error_name": error_name,
2562 "error_result": error_result,
2563 "error_reason": error_reason,
2564 "param_reqs": param_reqs,
2565 }
2566 return info_dict
2567
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002568
2569class TosaInvalidValidator:
2570 @staticmethod
2571 def ivWrongDataTypeOrModeResize(**kwargs):
2572 input_dtype = kwargs["input_dtype"]
2573 args = kwargs["args"]
2574 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002575 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002576
2577 if mode == ResizeMode.BILINEAR:
2578 # Invalid output data type / Invalid input datatype
2579 return (
2580 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002581 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002582 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002583 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002584 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002585 )
2586 elif mode == ResizeMode.NEAREST:
2587 # Invalid output data type / Invalid input datatype
2588 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002589 input_dtype
2590 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002591 )
2592 else:
2593 # Invalid resize mode
2594 return True
2595
2596 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002597 def ivHeightWidthInvalid(**kwargs):
2598 opName = kwargs["opName"]
2599
2600 inputShapes = kwargs["shapeList"]
2601 input_shape = inputShapes[0]
2602
2603 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002604
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002605 if isinstance(args, dict):
2606 args_dict = args
2607 else:
2608 # Create args_dict from list elements
2609 # TODO - Remove this once all NHWC operators' agFunctions have been
2610 # converted to args_dict output
2611
2612 # Skip accum_dtype arg (apart from MaxPool2D, which doesn't have one)
2613 stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
2614 args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
2615 # The same positional slot holds different info depending on the op
2616 args_dict["kernel"] = args[pad_idx + 1]
2617 args_dict["out_shape"] = args[pad_idx + 1]
2618 args_dict["dilation"] = args[pad_idx + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002619
2620 # Common info for all ops
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002621 strides = args_dict["stride"]
2622 padding = args_dict["pad"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002623
2624 if opName.endswith("pool2d"):
2625 # avg_pool2d, max_pool2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002626 kernel_shape = args_dict["kernel"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002627 h = (
2628 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
2629 ) // strides[0]
2630 w = (
2631 input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
2632 ) // strides[1]
2633 # return True if any dimension is < 1
2634 return h < 1 or w < 1
2635
2636 if opName.startswith("transpose_conv2d"):
2637 # transpose_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002638 output_shape = args_dict["out_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002639 filter_shape = inputShapes[1]
2640 kernel_shape = filter_shape[1:-1]
2641
TatWai Chong24594f52022-06-08 00:48:04 -07002642 def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002643 """Calculate the transpose_conv2d output size for a dimension."""
2644 return (in_size - 1) * stride + kernel_size + in_pad + out_pad
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002645
Jeremy Johnson0c716862023-04-13 17:18:19 +01002646 h = get_out_size(
2647 input_shape[1],
2648 strides[0],
2649 kernel_shape[0],
2650 padding[0],
2651 padding[1],
2652 )
2653 w = get_out_size(
2654 input_shape[2],
2655 strides[1],
2656 kernel_shape[1],
2657 padding[2],
2658 padding[3],
2659 )
2660 if output_shape[1] == h and output_shape[2] == w:
2661 return False
2662 # output shape does not match the expected shape
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002663 return True
2664
2665 if "conv2d" in opName or "conv3d" in opName:
2666 # conv2d, conv3d, depthwise_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002667 dilations = args_dict["dilation"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002668 filter_shape = inputShapes[1]
2669 kernel_shape = (
2670 filter_shape[0:2]
2671 if opName.startswith("depthwise_conv2d")
2672 else filter_shape[1:-1]
2673 )
2674
2675 for i in range(len(kernel_shape)):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002676 pad_offset = i * 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002677 dim = (
2678 input_shape[i + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002679 - 1
2680 + padding[pad_offset]
2681 + padding[pad_offset + 1]
2682 - (kernel_shape[i] - 1) * dilations[i]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002683 ) // strides[i] + 1
2684 # return True if any dimension is < 1
2685 if dim < 1:
2686 return True
2687 return False
2688
2689 assert False, f"Unrecognized Op: {opName}"
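# Note: the pooling height/width above uses the identity
#   (IH + pads + stride - kernel) // stride == (IH + pads - kernel) // stride + 1
# e.g. IH=16, pads=2, kernel=3, stride=2: (16 + 2 + 2 - 3) // 2 = 8, matching
# the y_correct formula in evPoolingOutputShapeMismatch.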
2690
2691 @staticmethod
2692 def ivNonPositiveOutputShape(**kwargs):
2693 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002694 output_shape = args[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002695 if output_shape[1] <= 0 or output_shape[2] <= 0:
2696 # Negative output shape
2697 return True
2698 return False