# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math

import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
    FFTInputShapeMismatch = "FFTInputShapeMismatch"
    FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
    ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
    ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
    BroadcastShapesMismatch = "BroadcastShapesMismatch"


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
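        # Layout note (an assumption, inferred from the index usage below and in
        # checkResizeParams): scale is [scale_y_n, scale_y_d, scale_x_n, scale_x_d],
        # while offset and border are [y, x] pairs, so offset/border entry `index`
        # pairs up with the numerator at scale[index * 2].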
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

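    # Illustrative behaviour (a sketch, not exhaustive): with kernel=(2, 2) and
    # error_name=ErrorIf.PadLargerEqualKernel, each wrongPad entry is drawn from
    # {kernel dim, kernel dim + 1, kernel dim + 2}, e.g. (2, 3, 2, 4), so at least
    # one pad value is >= the matching kernel dimension.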
    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

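    # Example (follows directly from the table above): RESCALE from INT8 may only
    # produce UINT8/INT8/INT16/INT32, so
    #   eiRescaleWrongOutputType(DType.INT8, DType.INT48)  -> True (invalid combo)
    #   eiRescaleWrongOutputType(DType.INT8, DType.INT32)  -> False (valid combo)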
    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

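    # Worked example: eiRestrictDimensions([1, 100, 100, 64]) first caps each
    # dimension at 32, giving [1, 32, 32, 32]; its product (32768) is already
    # under max_items (100000), so the shrinking loop never runs.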
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result

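    # Illustrative contract (inferred from the logic above): when a test targets
    # ErrorIf.AxisSmallerZero, the evAxisSmallerZero validator must report
    # error_result=True and every other validator must report False; any other
    # combination is printed as a diagnostic and makes this function return False.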
    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Use hand-picked incorrect ranks for ops that index fixed dimensions,
        # to avoid index errors
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            op = kwargs["op"]
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if op["op"] in [Op.SCATTER, Op.GATHER]:
                # SCATTER/GATHER add an indices input tensor in their build functions
                num_operands += 1
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            op = kwargs["op"]
            output_list = kwargs["output_list"]
            expected_length = 1
            if op["op"] in [Op.FFT2D, Op.RFFT2D]:
                expected_length = 2

            if len(output_list) != expected_length:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

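    # Note (an assumption about the imported constant): MAX_RESIZE_DIMENSION comes
    # from generator.tosa_utils (16384 at the time of writing); the 16584/16499
    # shapes in param_reqs above appear to be chosen to just exceed it.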
    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]

            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note batch is expected to be the first dim
                if (len(input_shape) in rank_range) and (
                    input_shape[0] != output_shape[0]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            for output in kwargs["result_tensors"]:
                output_shape = output.shape  # Note this is just (N, OH, OW, C)
                if (len(input_shape) in rank_range) and (
                    input_shape[3] != output_shape[3]
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )

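    # Worked example: scale=[4, 2, 4, 2] (a 2x up-scale in y and x), offset=[0, 0]
    # and border=[0, 0] satisfies every bound above, so checkResizeParams returns
    # True; flipping offset[0] to -5 (below -scale[0]) makes it return False.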
    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

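                # Worked example: an 8x8 input with scale=[4, 2, 4, 2], zero offset
                # and border gives output_y = output_x = ((8 - 1) * 4) // 2 + 1 = 15.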
                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = (
                kwargs["input2"].shape if "input2" in kwargs else input1_shape
            )
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            for output in kwargs["result_tensors"]:
                output_shape = output.shape
                if (
                    (len(input1_shape) != len(output_shape))
                    or (len(input2_shape) != len(output_shape))
                    or (len(input3_shape) != len(output_shape))
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input Dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )

            if len(input1_shape) == len(input2_shape) == len(input3_shape):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input3_shape,
                    TosaErrorValidator.calculateBroadcastShape(
                        input1_shape, input2_shape
                    ),
                )
                if calculated_shape is not None:
                    # Valid inputs - check for output mismatch
                    for output in kwargs["result_tensors"]:
                        output_shape = output.shape
                        if calculated_shape != output_shape:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        return qinfo[index]

    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWeightZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]

        # exclude inputs with INT8 weights
        inputDtypes = [
            t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
        ]

        error_name = ErrorIf.WeightZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Weight DType not INT8 and zero point not 0"

        if check:
            weight_dtype = kwargs["weight_dtype"]
            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputZeroPointNotZero(check=False, **kwargs):
        op = kwargs["op"]
        inputDtypes = [
            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
        ]

        error_name = ErrorIf.OutputZeroPointNotZero
        param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
        error_result = False
        error_reason = "Output DType not INT8 and zero point not 0"

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
            if op["op"] == Op.AVG_POOL2D:
                if input_dtype != DType.INT8 and output_zero_point != 0:
                    error_result = True
            elif (
                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
                and output_zero_point != 0
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16InputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16InputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
        error_result = False
        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            error_result = input_dtype == DType.UINT16 and input_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evU16OutputZeroPointNotValid(check=False, **kwargs):
        error_name = ErrorIf.U16OutputZeroPointNotValid
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"

        if check:
            output_dtype = kwargs["output_dtype"]
            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)

            error_result = output_dtype == DType.UINT16 and output_zero_point not in [
                0,
                32768,
            ]

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

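    # Example: TosaErrorValidator.evAxisSmallerZero(check=True, axis=-1) returns
    # a dict with error_result=True; with axis=0 the error_result is False.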
    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )

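    # Worked example: kernel=[2, 2], stride=[2, 2], pad=[0, 0, 0, 0] is valid
    # (returns True); pad=[2, 0, 0, 0] fails the pad-vs-kernel check (returns
    # False).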
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
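                # Worked example: IH = IW = 8 with a 2x2 kernel, 2x2 stride and
                # zero padding gives
                # y_correct = x_correct = ((8 + 0 + 0 - 2) // 2) + 1 = 4.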
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001485
1486 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001487 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001488
1489 if params_valid and (OH != y_correct or OW != x_correct):
1490 error_result = True
1491
1492 info_dict = {
1493 "error_name": error_name,
1494 "error_result": error_result,
1495 "error_reason": error_reason,
1496 "param_reqs": param_reqs,
1497 }
1498 return info_dict
1499
1500 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001501 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1502 error_name = ErrorIf.PoolingOutputShapeNonInteger
1503 param_reqs = {"rank": None, "dtype": None, "shape": None}
1504 error_result = False
1505 error_reason = "Parameters do not yield exact integer output dimensions"
1506
1507 if check:
1508 pad = kwargs["pad"]
1509 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1510
1511 kernel = kwargs["kernel"]
1512 kernel_y, kernel_x = kernel[0], kernel[1]
1513
1514 input_shape = kwargs["input_shape"]
1515 IH, IW = input_shape[1], input_shape[2]
1516
1517 stride = kwargs["stride"]
1518 stride_y, stride_x = stride[0], stride[1]
1519
1520 # calculate remainder of height, width dimensions
1521 if stride_x != 0 and stride_y != 0:
1522 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1523 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
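                # Worked example (editorial): IH=7, pad_top=pad_bottom=1,
                # kernel_y=3, stride_y=4 leaves (7 + 1 + 1 - 3) % 4 = 2, so the
                # output height is not an exact integer and gets flagged below.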

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkConvParams(op, weight_shape, stride, pad, dilation):
        if op == Op.TRANSPOSE_CONV2D:
            pad_ok = (
                pad[0] > -weight_shape[1]
                and pad[1] > -weight_shape[1]
                and pad[2] > -weight_shape[2]
                and pad[3] > -weight_shape[2]
            )
        else:
            pad_ok = min(pad) >= 0

        return (
            # Check kernel sizes
            min(weight_shape[1:-1]) >= 1
            and min(stride) >= 1
            and pad_ok
            and (dilation is None or min(dilation) >= 1)
        )
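
    # Illustrative note (editorial): TRANSPOSE_CONV2D permits negative padding
    # as long as it does not consume the whole kernel extent. For instance,
    # with weight_shape=[OC, 3, 3, IC] a pad value of -2 is accepted, but -3
    # is rejected because -3 > -weight_shape[1] is False.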

    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

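            # Worked example (editorial): for conv2d with input H=14,
            # pad=(1, 1), kernel_h=3, dilation=1, stride=2 the expected output
            # height is ((14 - 1 + 1 + 1 - 2) // 2) + 1 = 7.
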
            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate remainder of height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

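            # Worked example (editorial): input H=14, pad=(0, 0), kernel_h=3,
            # dilation=1, stride=3 leaves (14 - 1 - 2) % 3 = 2, so the output
            # height would not be an exact integer.
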
            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            dimension_match = True
            axis_shift = 0

            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

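            # Illustrative example (editorial): input [2, 4, 5] with axis=1
            # must produce output [2, 5]; the axis dimension is removed and
            # the remaining dimensions shift left by one.
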
            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputRankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]
            valid_params = 0 <= axis < len(input_shape)

            if valid_params and (len(input_shape) - 1) != len(output_shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.KernelSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one kernel dimension is smaller than one"

        if check:
            kernel = kwargs["kernel"]
            if min(kernel) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStrideSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one stride dimension is smaller than one"

        if check:
            stride = kwargs["stride"]
            if min(stride) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDilationSmallerOne(check=False, **kwargs):
        error_result = check and min(kwargs["dilation"]) < 1
        return {
            "error_name": ErrorIf.DilationSmallerOne,
            "error_reason": "At least one dilation is smaller than one",
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
            "error_result": error_result,
        }

    @staticmethod
    def evScaleTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleTrue
        param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
        error_result = False
        error_reason = "Scale set to true but input type is INT48"

        if check:
            input_dtype = kwargs["input_dtype"]
            scale32 = kwargs["scale32"]
            if scale32 and input_dtype == DType.INT48:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNotTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleNotTrue
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale set to false but double round set to true"

        if check:
            scale32 = kwargs["scale32"]
            double_round = kwargs["double_round"]
            if not scale32 and double_round:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evTensorSizeInputOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.TensorSizeInputOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input tensor size does not match output tensor size"

        if check:
            # Look up the op inside the check so validators called without
            # kwargs do not raise a KeyError
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            shape_inferencing = False
            if -1 in output_shape and op["op"] == Op.RESHAPE:
                shape_inferencing = True
            input_size = np.prod(input_shape)
            output_size = np.prod(output_shape)
            if input_size != output_size and not shape_inferencing:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.StartSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point smaller than zero"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            rank = len(input_shape)
            if len(start) == rank:
                for index in range(rank):
                    if start[index] < 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.SizeSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size smaller than or equal to zero"

        if check:
            input_shape = kwargs["input_shape"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(size) == rank:
                for index in range(rank):
                    if size[index] <= 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSizeOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.StartSizeOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point plus size larger than input dimension"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(start) == rank and len(size) == rank:
                for index in range(rank):
                    if start[index] + size[index] > input_shape[index]:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.SizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size does not match output dimension"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            size = kwargs["size"]

            if len(input_shape) == len(output_shape):
                rank = len(input_shape)
                if len(size) == rank:
                    for index in range(rank):
                        if size[index] != output_shape[index]:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputSizeStartLengthMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputSizeStartLengthMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank of input not equal to length of start or size"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if rank != len(start) or rank != len(size):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.IndexOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index outside of allowed bounds"

        if check:
            input_shape = kwargs["input_shape"]
            perms = kwargs["perms"]
            rank = len(input_shape)

            for index in perms:
                if index < 0 or index >= rank:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexUsedTwice(check=False, **kwargs):
        error_name = ErrorIf.IndexUsedTwice
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index used multiple times"

        if check:
            perms = kwargs["perms"]

            unique_indices = []
            for index in perms:
                if index in unique_indices:
                    error_result = True
                else:
                    unique_indices.append(index)

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.MaxSmallerMin
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Max value smaller than min value"

        if check:
            max_val = kwargs["max_val"]
            min_val = kwargs["min_val"]
            if max_val < min_val:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputRankMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input ranks are not identical"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputDimMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputDimMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions differ on too many axes"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_rank = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_rank = False

            if valid_rank:
                for input in inputs:
                    for i, dim in enumerate(input.shape):
                        if dim != input_shape[i] and axis != i:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatShapeSumMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatShapeSumMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Sum of dimensions on axis not equal to output dimension"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_params = True
            for input in inputs:
                if len(input.shape) != len(input_shape):
                    valid_params = False
            if axis < 0 or axis >= len(input_shape):
                valid_params = False

            if valid_params:
                axis_dim_sum = 0
                for input in inputs:
                    axis_dim_sum += input.shape[axis]

                if axis_dim_sum != output_shape[axis]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match then-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            then_block = basicBlocks[1]
            then_inputs = then_block.inputs
            then_tens = then_block.tensors
            if (a.shape != then_tens[then_inputs[0]].shape) or (
                b.shape != then_tens[then_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match else-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            else_block = basicBlocks[2]
            else_inputs = else_block.inputs
            else_tens = else_block.tensors
            if (a.shape != else_tens[else_inputs[0]].shape) or (
                b.shape != else_tens[else_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match then-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            then_block = basicBlocks[1]
            then_outputs = then_block.outputs
            then_tens = then_block.tensors
            if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match else-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            else_block = basicBlocks[2]
            else_outputs = else_block.outputs
            else_tens = else_block.tensors
            if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not match bool type"

        if check:
            cond = kwargs["cond"]
            if cond.dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not have a size of one"

        if check:
            cond = kwargs["cond"]
            # Size of 1 is equivalent to rank 0
            if len(cond.shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListOutputListMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListOutputListMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match output list"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_outputs = while_block.outputs
            while_tens = while_block.tensors
            if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListCondGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListCondGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match cond graph"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            cond_block = basicBlocks[1]
            cond_inputs = cond_block.inputs
            cond_tens = cond_block.tensors
            if (
                while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
            ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphInputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphInputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph input"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph output"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_outputs = body_block.outputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output does not match bool type"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output shape is not size one"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            # Size of 1 is equivalent to rank 0
            if len(cond_tens[cond_outputs[0]].shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelNotPowerOfTwo(check=False, **kwargs):
        error_name = ErrorIf.KernelNotPowerOfTwo
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Kernel height and/or width not a power of two"

        def is_power_of_two(x):
            return math.log(x, 2).is_integer()

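        # Editorial note: this float-based check is fine for the kernel sizes
        # the generator produces, but ln(x)/ln(2) can suffer rounding error for
        # very large x; an integer alternative is x > 0 and (x & (x - 1)) == 0.
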
        if check:
            shape = kwargs["input_shape"]
            if len(shape) == 3:
                valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
                error_result = not valid_kernel

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evFFTInputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTInputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Mismatch between real and imaginary input shapes"

        if check:
            input1 = kwargs["input1"]
            input2 = kwargs["input2"]

            if input1.shape != input2.shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evFFTOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.FFTOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between provided and expected output kernel (H, W) shape"
        )

        if check:
            op = kwargs["op"]
            input_shape = kwargs["input_shape"]

            if len(input_shape) == 3:
                output_shapes = kwargs["output_shape"]

                # Ignoring batch size (N) from input shape
                expected_shape = input_shape[1:]
                if op["op"] == Op.RFFT2D:
                    expected_shape[1] = expected_shape[1] // 2 + 1
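                    # Illustrative example (editorial): RFFT2D keeps only the
                    # non-redundant half of the spectrum, so W=8 gives
                    # 8 // 2 + 1 = 5 output bins along the width.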

                # Ignoring batch size (N) from output shapes
                output_shape_0 = output_shapes[0][1:]
                output_shape_1 = output_shapes[1][1:]
                # Ensure the kernel sizes (H, W) of both outputs match the expected shape
                if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evReshapeOutputSizeMultiInference(check=False, **kwargs):
        error_name = ErrorIf.ReshapeOutputSizeMultiInference
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Reshape output tensor contains more than one inferred dimension"

        if check:
            output_shape = kwargs["output_shape"]
            inferences = 0
            for dim in output_shape:
                if dim == -1:
                    inferences += 1
            if inferences > 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evReshapeOutputSizeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ReshapeOutputSizeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Reshape inferred output tensor dimension is non-integer"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            input_size = np.prod(input_shape)
            output_size = 1
            for dim in output_shape:
                if dim != -1:
                    output_size *= dim
            if -1 in output_shape and input_size % output_size != 0:
                error_result = True
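            # Worked example (editorial): input [2, 3, 4] has 24 elements;
            # output [2, -1, 3] infers 24 / 6 = 4 cleanly, while [5, -1]
            # leaves 24 % 5 = 4 and is flagged as non-integer.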

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def calculateBroadcastShape(input_shape_a, input_shape_b):
        if input_shape_a is not None and input_shape_b is not None:
            calculated_shape = input_shape_a.copy()
            for idx in range(len(calculated_shape)):
                if calculated_shape[idx] == 1:
                    calculated_shape[idx] = input_shape_b[idx]
                elif (
                    input_shape_b[idx] != 1
                    and input_shape_b[idx] != calculated_shape[idx]
                ):
                    return None
            return calculated_shape
        else:
            return None
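
    # Illustrative example (editorial): calculateBroadcastShape([1, 3, 1],
    # [2, 1, 4]) returns [2, 3, 4]; calculateBroadcastShape([2, 3], [3, 2])
    # returns None because neither mismatching dimension is 1.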

    @staticmethod
    def evBroadcastShapesMismatch(check=False, **kwargs):
        error_name = ErrorIf.BroadcastShapesMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Broadcast shape calculation failed"

        if check:
            input_shape_a = kwargs["input1"].shape
            input_shape_b = kwargs["input2"].shape
            input_shape_c = (
                kwargs["input3"].shape if "input3" in kwargs else input_shape_b
            )

            if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
                calculated_shape = TosaErrorValidator.calculateBroadcastShape(
                    input_shape_c,
                    TosaErrorValidator.calculateBroadcastShape(
                        input_shape_a, input_shape_b
                    ),
                )
                error_result = calculated_shape is None

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict


class TosaInvalidValidator:
    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[5]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / Invalid input datatype
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / Invalid input datatype
            return (input_dtype != output_dtype) or (
                input_dtype
                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
            )
        else:
            # Invalid resize mode
            return True

    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        # Skip the accum_dtype arg (max_pool2d is the only op that doesn't have one)
        stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)

        # Common info for all ops
        strides = args[stride_idx]
        padding = args[pad_idx]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[pad_idx + 1]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[pad_idx + 1]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension."""
                return (in_size - 1) * stride + kernel_size + in_pad + out_pad
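
            # Worked example (editorial): in_size=4, stride=2, kernel_size=3,
            # out_pad=in_pad=0 gives (4 - 1) * 2 + 3 = 9.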

            h = get_out_size(
                input_shape[1],
                strides[0],
                kernel_shape[0],
                padding[0],
                padding[1],
            )
            w = get_out_size(
                input_shape[2],
                strides[1],
                kernel_shape[1],
                padding[2],
                padding[3],
            )
            if output_shape[1] == h and output_shape[2] == w:
                return False
            # output shape does not match the expected shape
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[pad_idx + 1]
            filter_shape = inputShapes[1]
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                pad_offset = i * 2
                dim = (
                    input_shape[i + 1]
                    - 1
                    + padding[pad_offset]
                    + padding[pad_offset + 1]
                    - (kernel_shape[i] - 1) * dilations[i]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs["args"]
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
            return True
        return False