# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"

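# Each ErrorIf name serves two roles: the TosaErrorIfArgGen methods below use
# it to corrupt otherwise-valid parameters when generating a negative test,
# and the matching TosaErrorValidator.ev* check uses it to report whether the
# ERROR_IF condition was actually produced. An illustrative round trip:
#   info = TosaErrorValidator.evAxisSmallerZero(check=True, axis=-1)
#   assert info["error_name"] == ErrorIf.AxisSmallerZero
#   assert info["error_result"] is True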

class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType
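    # Note on eiResizeErrorIf above (illustrative; the four-element scale
    # follows the TOSA RESIZE layout [scale_y_n, scale_y_d, scale_x_n,
    # scale_x_d], with offset and border as [y, x] pairs):
    #   scale = [4, 2, 4, 2] describes a 2x upscale in both axes.
    #   ScaleNLargerMax pushes a numerator (scale[0] or scale[2]) above the
    #   1 << 11 (2048) limit; ScaleDLargerMax pushes a denominator to at
    #   least 16 times its numerator, which the validators below reject.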

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False
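    # Illustrative checks of the RESCALE type pairings above:
    #   eiRescaleWrongOutputType(DType.INT8, DType.INT32)  -> False (valid pair)
    #   eiRescaleWrongOutputType(DType.UINT16, DType.INT8) -> True  (UINT16 may
    #   only rescale to INT16)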

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape
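    # Worked example: eiRestrictDimensions([40, 40, 40, 40]) first clamps each
    # dimension to 32, then shrinks every dimension by one per iteration until
    # the element count fits: 32**4 = 1048576 > 100000, ..., 18**4 = 104976,
    # 17**4 = 83521 <= 100000, so the result is [17, 17, 17, 17].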

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
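    # Illustrative: for a CAST from FP16 the invalid output dtypes are BOOL
    # and INT48 (FP16 may legally cast to the other integer and float types),
    # so eiCastErrorIf(testGen, DType.FP16) returns [DType.BOOL, DType.INT48].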


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        if validator_fcns is None:
            # Nothing to do
            return True
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result
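    # Truth table for expected_result above (a validator "fires" when its
    # error_result is True):
    #   names match,  fired     -> expected (the injected error was caught)
    #   names match,  not fired -> "Missed ERROR_IF" is reported
    #   names differ, fired     -> "Unexpected ERROR_IF" is reported
    #   names differ, not fired -> expected (no cross-triggering)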

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            op = kwargs["op"]
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if op["op"] in [Op.SCATTER, Op.GATHER]:
                # SCATTER/GATHER add an indices input tensor in their build functions
                num_operands += 1
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
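    # Note: MAX_RESIZE_DIMENSION comes from generator.tosa_utils (16384 at
    # the time of writing), so the param_reqs shapes above (16584 and 16499)
    # are deliberately just over the limit.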

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = (
                    output.shape
                )  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )
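    # Illustrative: checkResizeParams([4, 2, 4, 2], [0, 0], [0, 0]) is True --
    # positive scales, numerators within 1 << 11, denominators under 16x the
    # numerators, and offsets/borders inside their [-scale, 16*scale) and
    # [-16*scale, scale) windows respectively.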

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
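    # Worked example of the expected-shape formula above: for an NHWC input of
    # [1, 8, 8, 1] with scale=[4, 2, 4, 2], offset=[0, 0], border=[0, 0]:
    #   output_y = ((8 - 1) * 4 - 0 + 0) // 2 + 1 = 15
    # and likewise output_x = 15, so any output shape other than
    # [1, 15, 15, 1] raises ResizeOutputShapeMismatch.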

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            if len(input1_shape) == len(input2_shape) == len(input3_shape):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
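    # Broadcast example (calculateBroadcastShape is defined later in this
    # class, outside this excerpt): shapes [1, 3, 1] and [2, 1, 4] broadcast
    # to [2, 3, 4]; an incompatible pair such as [2, 3, 1] and [3, 1, 4]
    # yields None, in which case no output check is made above.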

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1.
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )
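    # Illustrative: checkPoolingParams(kernel=[2, 2], stride=[2, 2],
    # pad=[0, 0, 0, 0]) is True, while any pad entry >= its kernel dimension
    # (e.g. pad=[2, 0, 0, 0] with a 2x2 kernel) makes it False.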

    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
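    # Worked example of the pooling formula above: IH = IW = 8 with a 2x2
    # kernel, stride 2 and no padding gives
    #   y_correct = ((8 + 0 + 0 - 2) // 2) + 1 = 4
    # so the only accepted output height/width is 4x4.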
1502
1503 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001504 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1505 error_name = ErrorIf.PoolingOutputShapeNonInteger
1506 param_reqs = {"rank": None, "dtype": None, "shape": None}
1507 error_result = False
1508 error_reason = "Parameters do not yield exact integer output dimensions"
1509
1510 if check:
1511 pad = kwargs["pad"]
1512 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1513
1514 kernel = kwargs["kernel"]
1515 kernel_y, kernel_x = kernel[0], kernel[1]
1516
1517 input_shape = kwargs["input_shape"]
1518 IH, IW = input_shape[1], input_shape[2]
1519
1520 stride = kwargs["stride"]
1521 stride_y, stride_x = stride[0], stride[1]
1522
1523 # calculate remainder of height, width dimensions
1524 if stride_x != 0 and stride_y != 0:
1525 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1526 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
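# e.g. IH = 7, kernel_y = 2, stride_y = 2, no padding leaves
# (7 + 0 + 0 - 2) % 2 = 1, so no exact integer output height exists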
1527
1528 # ensure parameters are valid
1529 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1530 if params_valid and (y_remainder != 0 or x_remainder != 0):
1531 error_result = True
1532
1533 info_dict = {
1534 "error_name": error_name,
1535 "error_result": error_result,
1536 "error_reason": error_reason,
1537 "param_reqs": param_reqs,
1538 }
1539 return info_dict
1540
1541 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001542 def checkConvParams(op, weight_shape, stride, pad, dilation):
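# TRANSPOSE_CONV2D permits negative out_pads, but each must stay greater
# than the negated kernel dimension it applies to; all other convolutions
# require non-negative padding.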
1543 if op == Op.TRANSPOSE_CONV2D:
1544 pad_ok = (
1545 pad[0] > -weight_shape[1]
1546 and pad[1] > -weight_shape[1]
1547 and pad[2] > -weight_shape[2]
1548 and pad[3] > -weight_shape[2]
1549 )
1550 else:
1551 pad_ok = min(pad) >= 0
1552
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001553 return (
1554 # Check kernel sizes
1555 min(weight_shape[1:-1]) >= 1
1556 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001557 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001558 and (dilation is None or min(dilation) >= 1)
1559 )
1560
1561 @staticmethod
1562 def evConvOutputShapeMismatch(check=False, **kwargs):
1563 error_name = ErrorIf.ConvOutputShapeMismatch
1564 param_reqs = {"rank": None, "dtype": None, "shape": None}
1565 error_result = False
1566 error_reason = (
1567 "Mismatch between output shape provided and expected output shape"
1568 )
1569
1570 if check:
1571 op = kwargs["op"]
1572 pad = kwargs["pad"]
1573 weight_shape = kwargs["weight_shape"]
1574 input_shape = kwargs["input_shape"]
1575 output_shape = kwargs["output_shape"]
1576 dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
1577 stride = kwargs["stride"]
1578
1579 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1580
1581 # calculate correct dimensions
1582 dims_correct = []
1583 if min(stride) > 0:
1584 for index in range(len(stride)):
1585 pad_offset = index * 2
1586 if op["op"] == Op.TRANSPOSE_CONV2D:
1587 dims_correct.append(
1588 (input_shape[index + 1] - 1) * stride[index]
Eric Kunzec1a97832022-07-01 16:56:09 -07001589 + pad[pad_offset]
1590 + pad[pad_offset + 1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001591 + weight_shape[index + kernel_offset]
1592 )
1593 else:
1594 dims_correct.append(
1595 (
1596 input_shape[index + 1]
1597 - 1
1598 + pad[pad_offset]
1599 + pad[pad_offset + 1]
1600 - (weight_shape[index + kernel_offset] - 1)
1601 * dilation[index]
1602 )
1603 // stride[index]
1604 + 1
1605 )
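# e.g. input dim 8, pads 0 + 0, kernel 3, dilation 1, stride 2:
# ((8 - 1 + 0 + 0 - (3 - 1) * 1) // 2) + 1 = 3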
1606
1607 # ensure parameters are valid
1608 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001609 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001610 )
1611
1612 if params_valid and output_shape[1:-1] != dims_correct:
1613 error_result = True
1614
1615 info_dict = {
1616 "error_name": error_name,
1617 "error_result": error_result,
1618 "error_reason": error_reason,
1619 "param_reqs": param_reqs,
1620 }
1621 return info_dict
1622
1623 @staticmethod
1624 def evConvOutputShapeNonInteger(check=False, **kwargs):
1625 error_name = ErrorIf.ConvOutputShapeNonInteger
1626 param_reqs = {"rank": None, "dtype": None, "shape": None}
1627 error_result = False
1628 error_reason = "Parameters do not yield exact integer output dimensions"
1629
1630 if check:
1631 op = kwargs["op"]
1632 pad = kwargs["pad"]
1633 weight_shape = kwargs["weight_shape"]
1634 input_shape = kwargs["input_shape"]
1635 dilation = kwargs["dilation"]
1636 stride = kwargs["stride"]
1637
1638 kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
1639
1640 # calculate height, width dimension remainders
1641 remainders = []
1642 if min(stride) > 0:
1643 for index in range(len(stride)):
1644 pad_offset = index * 2
1645 remainders.append(
1646 (
1647 input_shape[index + 1]
1648 - 1
1649 + pad[pad_offset]
1650 + pad[pad_offset + 1]
1651 - (weight_shape[index + kernel_offset] - 1)
1652 * dilation[index]
1653 )
1654 % stride[index]
1655 )
1656
1657 # ensure parameters are valid
1658 params_valid = TosaErrorValidator.checkConvParams(
Eric Kunzec1a97832022-07-01 16:56:09 -07001659 op["op"], weight_shape, stride, pad, dilation
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001660 )
1661 if params_valid and max(remainders) > 0:
1662 error_result = True
1663
1664 info_dict = {
1665 "error_name": error_name,
1666 "error_result": error_result,
1667 "error_reason": error_reason,
1668 "param_reqs": param_reqs,
1669 }
1670 return info_dict
1671
1672 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001673 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1674 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1675 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1676 error_result = False
1677 error_reason = (
1678 "Mismatch between output shape provided and expected output shape"
1679 )
1680
1681 if check:
1682 output_shape = kwargs["output_shape"]
1683 input_shape = kwargs["input_shape"]
1684 axis = kwargs["axis"]
1685
1686 dimension_match = True
1687 axis_shift = 0
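# The output shape should equal the input shape with the axis dimension
# removed, so dims after the axis shift left by one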
1688
1689 # Check that rank is correct before trying to check dimensions
1690 if (len(input_shape) - 1) == len(output_shape):
1691 for i in range(len(input_shape)):
1692 if i == axis:
1693 axis_shift = 1
1694 continue
1695 if input_shape[i] != output_shape[i - axis_shift]:
1696 dimension_match = False
1697
1698 if not dimension_match:
1699 error_result = True
1700
1701 info_dict = {
1702 "error_name": error_name,
1703 "error_result": error_result,
1704 "error_reason": error_reason,
1705 "param_reqs": param_reqs,
1706 }
1707 return info_dict
1708
1709 @staticmethod
1710 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1711 error_name = ErrorIf.ArgmaxOutputRankMismatch
1712 param_reqs = {"rank": None, "dtype": None, "shape": None}
1713 error_result = False
1714 error_reason = (
1715 "Mismatch between output shape provided and expected output shape"
1716 )
1717
1718 if check:
1719 output_shape = kwargs["output_shape"]
1720 input_shape = kwargs["input_shape"]
1721 axis = kwargs["axis"]
1722 valid_params = axis >= 0 and axis < len(input_shape)
1723
1724 if valid_params and (len(input_shape) - 1) != len(output_shape):
1725 error_result = True
1726
1727 info_dict = {
1728 "error_name": error_name,
1729 "error_result": error_result,
1730 "error_reason": error_reason,
1731 "param_reqs": param_reqs,
1732 }
1733 return info_dict
1734
1735 @staticmethod
1736 def evKernelSmallerOne(check=False, **kwargs):
1737 error_name = ErrorIf.KernelSmallerOne
1738 param_reqs = {"rank": None, "dtype": None, "shape": None}
1739 error_result = False
1740 error_reason = "At least one kernel dimension is smaller than one"
1741
1742 if check:
1743 kernel = kwargs["kernel"]
1744 if min(kernel) < 1:
1745 error_result = True
1746
1747 info_dict = {
1748 "error_name": error_name,
1749 "error_result": error_result,
1750 "error_reason": error_reason,
1751 "param_reqs": param_reqs,
1752 }
1753 return info_dict
1754
1755 @staticmethod
1756 def evStrideSmallerOne(check=False, **kwargs):
1757 error_name = ErrorIf.StrideSmallerOne
1758 param_reqs = {"rank": None, "dtype": None, "shape": None}
1759 error_result = False
1760 error_reason = "At least one stride dimension is smaller than one"
1761
1762 if check:
1763 stride = kwargs["stride"]
1764 if min(stride) < 1:
1765 error_result = True
1766
1767 info_dict = {
1768 "error_name": error_name,
1769 "error_result": error_result,
1770 "error_reason": error_reason,
1771 "param_reqs": param_reqs,
1772 }
1773 return info_dict
1774
1775 @staticmethod
1776 def evDilationSmallerOne(check=False, **kwargs):
1777 error_result = check and min(kwargs["dilation"]) < 1
1778 return {
1779 "error_name": ErrorIf.DilationSmallerOne,
1780 "error_reason": "At least one dilation is smaller than one",
1781 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1782 "error_result": error_result,
1783 }
1784
1785 @staticmethod
1786 def evScaleTrue(check=False, **kwargs):
1787 error_name = ErrorIf.ScaleTrue
1788 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1789 error_result = False
1790 error_reason = "Scale set to true but input type is INT48"
1791
1792 if check:
1793 input_dtype = kwargs["input_dtype"]
1794 scale32 = kwargs["scale32"]
1795 if scale32 and input_dtype == DType.INT48:
1796 error_result = True
1797
1798 info_dict = {
1799 "error_name": error_name,
1800 "error_result": error_result,
1801 "error_reason": error_reason,
1802 "param_reqs": param_reqs,
1803 }
1804 return info_dict
1805
1806 @staticmethod
1807 def evScaleNotTrue(check=False, **kwargs):
1808 error_name = ErrorIf.ScaleNotTrue
1809 param_reqs = {"rank": None, "dtype": None, "shape": None}
1810 error_result = False
1811 error_reason = "Scale set to false but double round set to true"
1812
1813 if check:
1814 scale32 = kwargs["scale32"]
1815 double_round = kwargs["double_round"]
1816 if not scale32 and double_round:
1817 error_result = True
1818
1819 info_dict = {
1820 "error_name": error_name,
1821 "error_result": error_result,
1822 "error_reason": error_reason,
1823 "param_reqs": param_reqs,
1824 }
1825 return info_dict
1826
1827 @staticmethod
1828 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1829 error_name = ErrorIf.TensorSizeInputOutputMismatch
1830 param_reqs = {"rank": None, "dtype": None, "shape": None}
1831 error_result = False
1832 error_reason = "Input tensor size does not match output tensor size"
Jerry Ge264f7fa2023-04-21 22:49:57 +00001833 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001834
1835 if check:
1836 input_shape = kwargs["input_shape"]
1837 output_shape = kwargs["output_shape"]
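# A -1 dim in a RESHAPE output marks an inferred dimension, so the
# element-count comparison below cannot be applied directly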
Jerry Ge264f7fa2023-04-21 22:49:57 +00001838 shape_inferencing = False
1839 if -1 in output_shape and op["op"] == Op.RESHAPE:
1840 shape_inferencing = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001841 input_size = np.prod(input_shape)
1842 output_size = np.prod(output_shape)
Jerry Ge264f7fa2023-04-21 22:49:57 +00001843 if input_size != output_size and not shape_inferencing:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001844 error_result = True
1845
1846 info_dict = {
1847 "error_name": error_name,
1848 "error_result": error_result,
1849 "error_reason": error_reason,
1850 "param_reqs": param_reqs,
1851 }
1852 return info_dict
1853
1854 @staticmethod
1855 def evStartSmallerZero(check=False, **kwargs):
1856 error_name = ErrorIf.StartSmallerZero
1857 param_reqs = {"rank": None, "dtype": None, "shape": None}
1858 error_result = False
1859 error_reason = "Starting point smaller than zero"
1860
1861 if check:
1862 input_shape = kwargs["input_shape"]
1863 start = kwargs["start"]
1864 rank = len(input_shape)
1865 if len(start) == rank:
1866 for index in range(rank):
1867 if start[index] < 0:
1868 error_result = True
1869
1870 info_dict = {
1871 "error_name": error_name,
1872 "error_result": error_result,
1873 "error_reason": error_reason,
1874 "param_reqs": param_reqs,
1875 }
1876 return info_dict
1877
1878 @staticmethod
1879 def evSizeSmallerEqualZero(check=False, **kwargs):
1880 error_name = ErrorIf.SizeSmallerEqualZero
1881 param_reqs = {"rank": None, "dtype": None, "shape": None}
1882 error_result = False
1883 error_reason = "Size smaller than or equal to zero"
1884
1885 if check:
1886 input_shape = kwargs["input_shape"]
1887 size = kwargs["size"]
1888 rank = len(input_shape)
1889 if len(size) == rank:
1890 for index in range(rank):
1891 if size[index] <= 0:
1892 error_result = True
1893
1894 info_dict = {
1895 "error_name": error_name,
1896 "error_result": error_result,
1897 "error_reason": error_reason,
1898 "param_reqs": param_reqs,
1899 }
1900 return info_dict
1901
1902 @staticmethod
1903 def evStartSizeOutsideBounds(check=False, **kwargs):
1904 error_name = ErrorIf.StartSizeOutsideBounds
1905 param_reqs = {"rank": None, "dtype": None, "shape": None}
1906 error_result = False
1907 error_reason = "Starting point plus size larger than input dimension"
1908
1909 if check:
1910 input_shape = kwargs["input_shape"]
1911 start = kwargs["start"]
1912 size = kwargs["size"]
1913 rank = len(input_shape)
1914 if len(start) == rank and len(size) == rank:
1915 for index in range(rank):
1916 if start[index] + size[index] > input_shape[index]:
1917 error_result = True
1918
1919 info_dict = {
1920 "error_name": error_name,
1921 "error_result": error_result,
1922 "error_reason": error_reason,
1923 "param_reqs": param_reqs,
1924 }
1925 return info_dict
1926
1927 @staticmethod
1928 def evSizeOutputShapeMismatch(check=False, **kwargs):
1929 error_name = ErrorIf.SizeOutputShapeMismatch
1930 param_reqs = {"rank": None, "dtype": None, "shape": None}
1931 error_result = False
1932 error_reason = "Size does not match output dimension"
1933
1934 if check:
1935 input_shape = kwargs["input_shape"]
1936 output_shape = kwargs["output_shape"]
1937 size = kwargs["size"]
Luke Huttona4e48ca2023-02-22 11:53:48 +00001938
1939 if len(input_shape) == len(output_shape):
1940 rank = len(input_shape)
1941 if len(size) == rank:
1942 for index in range(rank):
1943 if size[index] != output_shape[index]:
1944 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001945
1946 info_dict = {
1947 "error_name": error_name,
1948 "error_result": error_result,
1949 "error_reason": error_reason,
1950 "param_reqs": param_reqs,
1951 }
1952 return info_dict
1953
1954 @staticmethod
1955 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1956 error_name = ErrorIf.InputSizeStartLengthMismatch
1957 param_reqs = {"rank": None, "dtype": None, "shape": None}
1958 error_result = False
1959 error_reason = "Rank of input not equal to length of start or size"
1960
1961 if check:
1962 input_shape = kwargs["input_shape"]
1963 start = kwargs["start"]
1964 size = kwargs["size"]
1965 rank = len(input_shape)
1966 if rank != len(start) or rank != len(size):
1967 error_result = True
1968
1969 info_dict = {
1970 "error_name": error_name,
1971 "error_result": error_result,
1972 "error_reason": error_reason,
1973 "param_reqs": param_reqs,
1974 }
1975 return info_dict
1976
1977 @staticmethod
1978 def evIndexOutsideBounds(check=False, **kwargs):
1979 error_name = ErrorIf.IndexOutsideBounds
1980 param_reqs = {"rank": None, "dtype": None, "shape": None}
1981 error_result = False
1982 error_reason = "Index outside of allowed bounds"
1983
1984 if check:
1985 input_shape = kwargs["input_shape"]
1986 perms = kwargs["perms"]
1987 rank = len(input_shape)
1988
1989 for index in perms:
1990 if index < 0 or index >= rank:  # valid indices are 0 .. rank - 1
1991 error_result = True
1992
1993 info_dict = {
1994 "error_name": error_name,
1995 "error_result": error_result,
1996 "error_reason": error_reason,
1997 "param_reqs": param_reqs,
1998 }
1999 return info_dict
2000
2001 @staticmethod
2002 def evIndexUsedTwice(check=False, **kwargs):
2003 error_name = ErrorIf.IndexUsedTwice
2004 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2005 error_result = False
2006 error_reason = "Index used multiple times"
2007
2008 if check:
2009 perms = kwargs["perms"]
2010
2011 unique_indices = []
2012 for index in perms:
2013 if index in unique_indices:
2014 error_result = True
2015 else:
2016 unique_indices.append(index)
2017
2018 info_dict = {
2019 "error_name": error_name,
2020 "error_result": error_result,
2021 "error_reason": error_reason,
2022 "param_reqs": param_reqs,
2023 }
2024 return info_dict
2025
2026 @staticmethod
2027 def evMaxSmallerMin(check=False, **kwargs):
2028 error_name = ErrorIf.MaxSmallerMin
2029 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2030 error_result = False
2031 error_reason = "Max value smaller than min value"
2032
2033 if check:
2034 max_val = kwargs["max_val"]
2035 min_val = kwargs["min_val"]
2036 if max_val < min_val:
2037 error_result = True
2038
2039 info_dict = {
2040 "error_name": error_name,
2041 "error_result": error_result,
2042 "error_reason": error_reason,
2043 "param_reqs": param_reqs,
2044 }
2045 return info_dict
2046
2047 @staticmethod
2048 def evConcatInputRankMismatch(check=False, **kwargs):
2049 error_name = ErrorIf.ConcatInputRankMismatch
2050 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2051 error_result = False
2052 error_reason = "Input ranks are not identical"
2053
2054 if check:
2055 inputs = kwargs["inputs"]
2056 input_shape = kwargs["input_shape"]
2057 for tensor in inputs:
2058 if len(tensor.shape) != len(input_shape):
2059 error_result = True
2060
2061 info_dict = {
2062 "error_name": error_name,
2063 "error_result": error_result,
2064 "error_reason": error_reason,
2065 "param_reqs": param_reqs,
2066 }
2067 return info_dict
2068
2069 @staticmethod
2070 def evConcatInputDimMismatch(check=False, **kwargs):
2071 error_name = ErrorIf.ConcatInputDimMismatch
2072 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2073 error_result = False
2074 error_reason = "Input dimensions differ on an axis other than the concatenation axis"
2075
2076 if check:
2077 inputs = kwargs["inputs"]
2078 input_shape = kwargs["input_shape"]
2079 axis = kwargs["axis"]
2080
2081 # Ensure rank is valid before checking dims.
2082 valid_rank = True
2083 for tensor in inputs:
2084 if len(tensor.shape) != len(input_shape):
2085 valid_rank = False
2086
2087 if valid_rank:
2088 for tensor in inputs:
2089 for i, dim in enumerate(tensor.shape):
2090 if dim != input_shape[i] and axis != i:
2091 error_result = True
2092
2093 info_dict = {
2094 "error_name": error_name,
2095 "error_result": error_result,
2096 "error_reason": error_reason,
2097 "param_reqs": param_reqs,
2098 }
2099 return info_dict
2100
2101 @staticmethod
2102 def evConcatShapeSumMismatch(check=False, **kwargs):
2103 error_name = ErrorIf.ConcatShapeSumMismatch
2104 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2105 error_result = False
2106 error_reason = "Sum of dimensions on axis not equal to output dimension"
2107
2108 if check:
2109 inputs = kwargs["inputs"]
2110 input_shape = kwargs["input_shape"]
2111 output_shape = kwargs["output_shape"]
2112 axis = kwargs["axis"]
2113
2114 # Ensure rank is valid before checking dims.
2115 valid_params = True
2116 for tensor in inputs:
2117 if len(tensor.shape) != len(input_shape):
2118 valid_params = False
2119 if axis < 0 or axis >= len(input_shape):
2120 valid_params = False
2121
2122 if valid_params:
2123 axis_dim_sum = 0
2124 for tensor in inputs:
2125 axis_dim_sum += tensor.shape[axis]
2126
2127 if axis_dim_sum != output_shape[axis]:
2128 error_result = True
2129
2130 info_dict = {
2131 "error_name": error_name,
2132 "error_result": error_result,
2133 "error_reason": error_reason,
2134 "param_reqs": param_reqs,
2135 }
2136 return info_dict
2137
2138 @staticmethod
2139 def evInputListThenGraphMismatch(check=False, **kwargs):
2140 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2141 param_reqs = {"rank": None, "dtype": None, "shape": None}
2142 error_result = False
2143 error_reason = "Input list shape does not match then-graph shape"
2144
2145 if check:
2146 a = kwargs["a"]
2147 b = kwargs["b"]
2148 basicBlocks = kwargs["basicBlocks"]
2149 then_block = basicBlocks[1]
2150 then_inputs = then_block.inputs
2151 then_tens = then_block.tensors
2152 if (a.shape != then_tens[then_inputs[0]].shape) or (
2153 b.shape != then_tens[then_inputs[1]].shape
2154 ):
2155 error_result = True
2156
2157 info_dict = {
2158 "error_name": error_name,
2159 "error_result": error_result,
2160 "error_reason": error_reason,
2161 "param_reqs": param_reqs,
2162 }
2163 return info_dict
2164
2165 @staticmethod
2166 def evInputListElseGraphMismatch(check=False, **kwargs):
2167 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2168 param_reqs = {"rank": None, "dtype": None, "shape": None}
2169 error_result = False
2170 error_reason = "Input list shape does not match else-graph shape"
2171
2172 if check:
2173 a = kwargs["a"]
2174 b = kwargs["b"]
2175 basicBlocks = kwargs["basicBlocks"]
2176 else_block = basicBlocks[2]
2177 else_inputs = else_block.inputs
2178 else_tens = else_block.tensors
2179 if (a.shape != else_tens[else_inputs[0]].shape) or (
2180 b.shape != else_tens[else_inputs[1]].shape
2181 ):
2182 error_result = True
2183
2184 info_dict = {
2185 "error_name": error_name,
2186 "error_result": error_result,
2187 "error_reason": error_reason,
2188 "param_reqs": param_reqs,
2189 }
2190 return info_dict
2191
2192 @staticmethod
2193 def evOutputListThenGraphMismatch(check=False, **kwargs):
2194 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2195 param_reqs = {"rank": None, "dtype": None, "shape": None}
2196 error_result = False
2197 error_reason = "Output list shape does not match then-graph shape"
2198
2199 if check:
2200 basicBlocks = kwargs["basicBlocks"]
2201 cond_block = basicBlocks[0]
2202 cond_outputs = cond_block.outputs
2203 cond_tens = cond_block.tensors
2204 then_block = basicBlocks[1]
2205 then_outputs = then_block.outputs
2206 then_tens = then_block.tensors
2207 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2208 error_result = True
2209
2210 info_dict = {
2211 "error_name": error_name,
2212 "error_result": error_result,
2213 "error_reason": error_reason,
2214 "param_reqs": param_reqs,
2215 }
2216 return info_dict
2217
2218 @staticmethod
2219 def evOutputListElseGraphMismatch(check=False, **kwargs):
2220 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2221 param_reqs = {"rank": None, "dtype": None, "shape": None}
2222 error_result = False
2223 error_reason = "Output list shape does not match else-graph shape"
2224
2225 if check:
2226 basicBlocks = kwargs["basicBlocks"]
2227 cond_block = basicBlocks[0]
2228 cond_outputs = cond_block.outputs
2229 cond_tens = cond_block.tensors
2230 else_block = basicBlocks[2]
2231 else_outputs = else_block.outputs
2232 else_tens = else_block.tensors
2233 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2234 error_result = True
2235
2236 info_dict = {
2237 "error_name": error_name,
2238 "error_result": error_result,
2239 "error_reason": error_reason,
2240 "param_reqs": param_reqs,
2241 }
2242 return info_dict
2243
2244 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002245 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2246 error_name = ErrorIf.CondIfCondNotMatchingBool
2247 param_reqs = {"rank": None, "dtype": None, "shape": None}
2248 error_result = False
2249 error_reason = "Conditional tensor does not match bool type"
2250
2251 if check:
2252 cond = kwargs["cond"]
2253 if cond.dtype != DType.BOOL:
2254 error_result = True
2255
2256 info_dict = {
2257 "error_name": error_name,
2258 "error_result": error_result,
2259 "error_reason": error_reason,
2260 "param_reqs": param_reqs,
2261 }
2262 return info_dict
2263
2264 @staticmethod
2265 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2266 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2267 param_reqs = {"rank": None, "dtype": None, "shape": None}
2268 error_result = False
2269 error_reason = "Conditional tensor does not have a size of one"
2270
2271 if check:
2272 cond = kwargs["cond"]
2273 # Size of 1 is equivalent to rank 0
2274 if len(cond.shape) != 0:
2275 error_result = True
2276
2277 info_dict = {
2278 "error_name": error_name,
2279 "error_result": error_result,
2280 "error_reason": error_reason,
2281 "param_reqs": param_reqs,
2282 }
2283 return info_dict
2284
2285 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002286 def evInputListOutputListMismatch(check=False, **kwargs):
2287 error_name = ErrorIf.InputListOutputListMismatch
2288 param_reqs = {"rank": None, "dtype": None, "shape": None}
2289 error_result = False
2290 error_reason = "Input list does not match output list"
2291
2292 if check:
2293 basicBlocks = kwargs["basicBlocks"]
2294 while_block = basicBlocks[0]
2295 while_inputs = while_block.inputs
2296 while_outputs = while_block.outputs
2297 while_tens = while_block.tensors
2298 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2299 error_result = True
2300
2301 info_dict = {
2302 "error_name": error_name,
2303 "error_result": error_result,
2304 "error_reason": error_reason,
2305 "param_reqs": param_reqs,
2306 }
2307 return info_dict
2308
2309 @staticmethod
2310 def evInputListCondGraphMismatch(check=False, **kwargs):
2311 error_name = ErrorIf.InputListCondGraphMismatch
2312 param_reqs = {"rank": None, "dtype": None, "shape": None}
2313 error_result = False
2314 error_reason = "Input list does not match cond graph"
2315
2316 if check:
2317 basicBlocks = kwargs["basicBlocks"]
2318 while_block = basicBlocks[0]
2319 while_inputs = while_block.inputs
2320 while_tens = while_block.tensors
2321 cond_block = basicBlocks[1]
2322 cond_inputs = cond_block.inputs
2323 cond_tens = cond_block.tensors
2324 if (
2325 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2326 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2327 error_result = True
2328
2329 info_dict = {
2330 "error_name": error_name,
2331 "error_result": error_result,
2332 "error_reason": error_reason,
2333 "param_reqs": param_reqs,
2334 }
2335 return info_dict
2336
2337 @staticmethod
2338 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2339 error_name = ErrorIf.InputListBodyGraphInputMismatch
2340 param_reqs = {"rank": None, "dtype": None, "shape": None}
2341 error_result = False
2342 error_reason = "Input list does not match body graph input"
2343
2344 if check:
2345 basicBlocks = kwargs["basicBlocks"]
2346 while_block = basicBlocks[0]
2347 while_inputs = while_block.inputs
2348 while_tens = while_block.tensors
2349 body_block = basicBlocks[2]
2350 body_inputs = body_block.inputs
2351 body_tens = body_block.tensors
2352 if (
2353 while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
2354 ) or (
2355 while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
2356 ):
2357 error_result = True
2358
2359 info_dict = {
2360 "error_name": error_name,
2361 "error_result": error_result,
2362 "error_reason": error_reason,
2363 "param_reqs": param_reqs,
2364 }
2365 return info_dict
2366
2367 @staticmethod
2368 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2369 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2370 param_reqs = {"rank": None, "dtype": None, "shape": None}
2371 error_result = False
2372 error_reason = "Input list does not match body graph output"
2373
2374 if check:
2375 basicBlocks = kwargs["basicBlocks"]
2376 while_block = basicBlocks[0]
2377 while_inputs = while_block.inputs
2378 while_tens = while_block.tensors
2379 body_block = basicBlocks[2]
2380 body_outputs = body_block.outputs
2381 body_tens = body_block.tensors
2382 if (
2383 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2384 ) or (
2385 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2386 ):
2387 error_result = True
2388 info_dict = {
2389 "error_name": error_name,
2390 "error_result": error_result,
2391 "error_reason": error_reason,
2392 "param_reqs": param_reqs,
2393 }
2394 return info_dict
2395
2396 @staticmethod
2397 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2398 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2399 param_reqs = {"rank": None, "dtype": None, "shape": None}
2400 error_result = False
2401 error_reason = "Cond graph output does not match bool type"
2402
2403 if check:
2404 basicBlocks = kwargs["basicBlocks"]
2405 cond_block = basicBlocks[1]
2406 cond_outputs = cond_block.outputs
2407 cond_tens = cond_block.tensors
2408 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2409 error_result = True
2410
2411 info_dict = {
2412 "error_name": error_name,
2413 "error_result": error_result,
2414 "error_reason": error_reason,
2415 "param_reqs": param_reqs,
2416 }
2417 return info_dict
2418
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002419 @staticmethod
2420 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2421 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2422 param_reqs = {"rank": None, "dtype": None, "shape": None}
2423 error_result = False
2424 error_reason = "Cond graph output does not have a shape of size one"
2425
2426 if check:
2427 basicBlocks = kwargs["basicBlocks"]
2428 cond_block = basicBlocks[1]
2429 cond_outputs = cond_block.outputs
2430 cond_tens = cond_block.tensors
2431 # Size of 1 is equivalent to rank 0
2432 if len(cond_tens[cond_outputs[0]].shape) != 0:
2433 error_result = True
2434
2435 info_dict = {
2436 "error_name": error_name,
2437 "error_result": error_result,
2438 "error_reason": error_reason,
2439 "param_reqs": param_reqs,
2440 }
2441 return info_dict
2442
Luke Hutton261b7b62023-01-10 14:50:31 +00002443 @staticmethod
2444 def evKernelNotPowerOfTwo(check=False, **kwargs):
2445 error_name = ErrorIf.KernelNotPowerOfTwo
2446 param_reqs = {"rank": None, "dtype": None, "shape": None}
2447 error_result = False
2448 error_reason = "kernel height and/or width not a power of two"
2449
2450 def is_power_of_two(x):
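# log base 2 of a power of two is an integer, e.g. log2(8) = 3.0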
2451 return math.log(x, 2).is_integer()
2452
2453 if check:
2454 shape = kwargs["input_shape"]
2455 if len(shape) == 3:
2456 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2457 error_result = not valid_kernel
2458
2459 info_dict = {
2460 "error_name": error_name,
2461 "error_result": error_result,
2462 "error_reason": error_reason,
2463 "param_reqs": param_reqs,
2464 }
2465 return info_dict
2466
Luke Hutton57287132023-02-06 14:54:18 +00002467 @staticmethod
2468 def evFFTInputShapeMismatch(check=False, **kwargs):
2469 error_name = ErrorIf.FFTInputShapeMismatch
2470 param_reqs = {"rank": None, "dtype": None, "shape": None}
2471 error_result = False
2472 error_reason = "Mismatch between real and imaginary input shapes"
2473
2474 if check:
2475 input1 = kwargs["input1"]
2476 input2 = kwargs["input2"]
2477
2478 if input1.shape != input2.shape:
2479 error_result = True
2480
2481 info_dict = {
2482 "error_name": error_name,
2483 "error_result": error_result,
2484 "error_reason": error_reason,
2485 "param_reqs": param_reqs,
2486 }
2487 return info_dict
2488
2489 @staticmethod
2490 def evFFTOutputShapeMismatch(check=False, **kwargs):
2491 error_name = ErrorIf.FFTOutputShapeMismatch
2492 param_reqs = {"rank": None, "dtype": None, "shape": None}
2493 error_result = False
2494 error_reason = (
2495 "Mismatch between provided and expected output kernel (H, W) shape"
2496 )
2497
2498 if check:
2499 op = kwargs["op"]
2500 input_shape = kwargs["input_shape"]
2501
2502 if len(input_shape) == 3:
2503 output_shapes = kwargs["output_shape"]
2504
2505 # Ignoring batch size (N) from input shape
2506 expected_shape = input_shape[1:]
2507 if op["op"] == Op.RFFT2D:
2508 expected_shape[1] = expected_shape[1] // 2 + 1
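# RFFT2D keeps only W // 2 + 1 output columns: the spectrum of a
# real-valued input is conjugate-symmetric, so the upper half is redundant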
2509
2510 # Ignoring batch size (N) from output shapes
2511 output_shape_0 = output_shapes[0][1:]
2512 output_shape_1 = output_shapes[1][1:]
2513 # Ensure the kernel sizes (H, W) of both outputs match the expected shape
2514 if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
2515 error_result = True
2516
2517 info_dict = {
2518 "error_name": error_name,
2519 "error_result": error_result,
2520 "error_reason": error_reason,
2521 "param_reqs": param_reqs,
2522 }
2523 return info_dict
2524
Jerry Ge264f7fa2023-04-21 22:49:57 +00002525 @staticmethod
2526 def evReshapeOutputSizeMultiInference(check=False, **kwargs):
2527 error_name = ErrorIf.ReshapeOutputSizeMultiInference
2528 param_reqs = {"rank": None, "dtype": None, "shape": None}
2529 error_result = False
2530 error_reason = "Reshape output tensor contains more than one inferred dimension"
2531
2532 if check:
2533 output_shape = kwargs["output_shape"]
2534 inferences = 0
2535 for dim in output_shape:
2536 if dim == -1:
2537 inferences += 1
2538 if inferences > 1:
2539 error_result = True
2540
2541 info_dict = {
2542 "error_name": error_name,
2543 "error_result": error_result,
2544 "error_reason": error_reason,
2545 "param_reqs": param_reqs,
2546 }
2547 return info_dict
2548
2549 @staticmethod
2550 def evReshapeOutputSizeNonInteger(check=False, **kwargs):
2551 error_name = ErrorIf.ReshapeOutputSizeNonInteger
2552 param_reqs = {"rank": None, "dtype": None, "shape": None}
2553 error_result = False
2554 error_reason = "Reshape inferred output tensor dimension is non-integer"
2555
2556 if check:
2557 input_shape = kwargs["input_shape"]
2558 output_shape = kwargs["output_shape"]
2559 input_size = np.prod(input_shape)
2560 output_size = 1
2561 for dim in output_shape:
2562 if dim != -1:
2563 output_size *= dim
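# e.g. input [2, 3] (size 6) reshaped to [-1, 4]: 6 % 4 != 0, so the
# inferred dimension cannot be an integer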
2564 if -1 in output_shape and input_size % output_size != 0:
2565 error_result = True
2566
2567 info_dict = {
2568 "error_name": error_name,
2569 "error_result": error_result,
2570 "error_reason": error_reason,
2571 "param_reqs": param_reqs,
2572 }
2573 return info_dict
2574
Jerry Ge135c9552023-05-23 20:59:32 +00002575 @staticmethod
2576 def calculateBroadcastShape(input_shape_a, input_shape_b):
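# Assumes equal ranks: each dim pair must match or contain a 1, e.g.
# [1, 3] and [2, 1] broadcast to [2, 3], while [2, 3] and [3, 3] fail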
2577 if input_shape_a is not None and input_shape_b is not None:
2578 calculated_shape = input_shape_a.copy()
2579 for idx in range(len(calculated_shape)):
2580 if calculated_shape[idx] == 1:
2581 calculated_shape[idx] = input_shape_b[idx]
2582 elif (
2583 input_shape_b[idx] != 1
2584 and input_shape_b[idx] != calculated_shape[idx]
2585 ):
2586 return None
2587 return calculated_shape
2588 else:
2589 return None
2590
2591 @staticmethod
2592 def evBroadcastShapesMismatch(check=False, **kwargs):
2593 error_name = ErrorIf.BroadcastShapesMismatch
2594 param_reqs = {"rank": None, "dtype": None, "shape": None}
2595 error_result = False
2596 error_reason = "Broadcast shape calculation failed"
2597
2598 if check:
2599 input_shape_a = kwargs["input1"].shape
2600 input_shape_b = kwargs["input2"].shape
2601 input_shape_c = (
2602 kwargs["input3"].shape if "input3" in kwargs else input_shape_b
2603 )
2604
2605 if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
2606 calculated_shape = TosaErrorValidator.calculateBroadcastShape(
2607 input_shape_c,
2608 TosaErrorValidator.calculateBroadcastShape(
2609 input_shape_a, input_shape_b
2610 ),
2611 )
2612 error_result = calculated_shape is None
2613
2614 info_dict = {
2615 "error_name": error_name,
2616 "error_result": error_result,
2617 "error_reason": error_reason,
2618 "param_reqs": param_reqs,
2619 }
2620 return info_dict
2621
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002622
2623class TosaInvalidValidator:
2624 @staticmethod
2625 def ivWrongDataTypeOrModeResize(**kwargs):
2626 input_dtype = kwargs["input_dtype"]
2627 args = kwargs["args"]
2628 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002629 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002630
2631 if mode == ResizeMode.BILINEAR:
2632 # Invalid output data type / Invalid input datatype
2633 return (
2634 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002635 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002636 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002637 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002638 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002639 )
2640 elif mode == ResizeMode.NEAREST:
2641 # Invalid output data type / Invalid input datatype
2642 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002643 input_dtype
2644 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002645 )
2646 else:
2647 # Invalid resize mode
2648 return True
2649
2650 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002651 def ivHeightWidthInvalid(**kwargs):
2652 opName = kwargs["opName"]
2653
2654 inputShapes = kwargs["shapeList"]
2655 input_shape = inputShapes[0]
2656
2657 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002658
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002659 if isinstance(args, dict):
2660 args_dict = args
2661 else:
2662 # Create args_dict from list elements
2663 # TODO - Remove this once all NHWC operators' agFunctions have been
2664 # converted to args_dict output
2665
2666 # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
2667 stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
2668 args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
2669 # The final list element differs by op, so alias it under each name
2670 args_dict["kernel"] = args[pad_idx + 1]
2671 args_dict["out_shape"] = args[pad_idx + 1]
2672 args_dict["dilation"] = args[pad_idx + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002673
2674 # Common info for all ops
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002675 strides = args_dict["stride"]
2676 padding = args_dict["pad"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002677
2678 if opName.endswith("pool2d"):
2679 # avg_pool2d, max_pool2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002680 kernel_shape = args_dict["kernel"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002681 h = (
2682 input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
2683 ) // strides[0]
2684 w = (
2685 input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
2686 ) // strides[1]
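# i.e. ((dim + pad_before + pad_after - kernel) // stride) + 1 for each
# spatial dimension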
2687 # return True if any dimension is < 1
2688 return h < 1 or w < 1
2689
2690 if opName.startswith("transpose_conv2d"):
2691 # transpose_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002692 output_shape = args_dict["out_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002693 filter_shape = inputShapes[1]
2694 kernel_shape = filter_shape[1:-1]
2695
TatWai Chong24594f52022-06-08 00:48:04 -07002696 def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002697 """Calculate the transpose_conv2d output size for a dimension."""
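# e.g. in_size = 4, stride = 2, kernel_size = 3, zero pads:
# (4 - 1) * 2 + 3 = 9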
2698 return (in_size - 1) * stride + kernel_size + in_pad + out_pad
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002699
Jeremy Johnson0c716862023-04-13 17:18:19 +01002700 h = get_out_size(
2701 input_shape[1],
2702 strides[0],
2703 kernel_shape[0],
2704 padding[0],
2705 padding[1],
2706 )
2707 w = get_out_size(
2708 input_shape[2],
2709 strides[1],
2710 kernel_shape[1],
2711 padding[2],
2712 padding[3],
2713 )
2714 if output_shape[1] == h and output_shape[2] == w:
2715 return False
2716 # output shape does not match the expected shape
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002717 return True
2718
2719 if "conv2d" in opName or "conv3d" in opName:
2720 # conv2d, conv3d, depthwise_conv2d
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002721 dilations = args_dict["dilation"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002722 filter_shape = inputShapes[1]
2723 kernel_shape = (
2724 filter_shape[0:2]
2725 if opName.startswith("depthwise_conv2d")
2726 else filter_shape[1:-1]
2727 )
2728
2729 for i in range(len(kernel_shape)):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002730 pad_offset = i * 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002731 dim = (
2732 input_shape[i + 1]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002733 - 1
2734 + padding[pad_offset]
2735 + padding[pad_offset + 1]
2736 - (kernel_shape[i] - 1) * dilations[i]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002737 ) // strides[i] + 1
2738 # return True if any dimension is < 1
2739 if dim < 1:
2740 return True
2741 return False
2742
2743 assert False, f"Unrecognized Op: {opName}"
2744
2745 @staticmethod
2746 def ivNonPositiveOutputShape(**kwargs):
2747 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002748 output_shape = args[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002749 if output_shape[1] <= 0 or output_shape[2] <= 0:
2750 # Negative output shape
2751 return True
2752 return False
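

# Illustrative usage sketch (assumed caller pattern, not part of the
# generator flow): every ev* checker returns an info dict whose
# "error_result" flags whether the ERROR_IF condition fired for the
# supplied kwargs, e.g.
#
#   info = TosaErrorValidator.evMaxSmallerMin(check=True, max_val=1, min_val=3)
#   assert info["error_result"] and info["error_name"] == ErrorIf.MaxSmallerMin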