blob: 93f975d112b8a05dd63026037e9362f77e5cc918 [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Luke Hutton261b7b62023-01-10 14:50:31 +00003import math
4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01006from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from generator.tosa_utils import product
8from generator.tosa_utils import usableDTypes
9from generator.tosa_utils import valueToName
10from tosa.DType import DType
11from tosa.Op import Op
12from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000013
Matthew Haddone86fd342021-09-07 16:12:21 +010014
class ErrorIf(object):
    """String identifiers for every ERROR_IF condition the generator can build.

    Each attribute's value is identical to its attribute name; the strings are
    used to select which ERROR_IF test case to construct and to report which
    validator fired (see TosaErrorValidator.evValidateErrorIfs).
    """

    # RESIZE argument checks (scale/offset/border limits and dimensions)
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    # Generic operator interface checks (types, tensor lists, ranks, shapes)
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    # Pooling / convolution argument checks
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    # SLICE argument checks
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    # CONCAT checks
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    # COND_IF / WHILE_LOOP control-flow graph checks
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
    KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010082
83
class TosaErrorIfArgGen:
    """Helpers that mangle otherwise-valid operator arguments so that a
    specific ERROR_IF condition is triggered in the reference model.

    Fixes vs previous version:
    - ``eiSliceErrorIf`` is now a ``@staticmethod`` like every sibling method.
    - The ``InputSizeStartLengthMismatch`` branch no longer mutates the
      caller's ``start``/``size`` lists in place; it returns new lists.
    """

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Adjust RESIZE arguments to provoke the requested ERROR_IF.

        ``scale`` is indexed 0-3 and ``offset``/``border`` 0-1 (y/x pairs);
        they are modified in place and also returned along with the
        (possibly replaced) ``outputDType``.
        """
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            # Force one scale component to be <= 0
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # Numerators (indices 0 and 2) must not exceed 1 << 11
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # Denominators (indices 1 and 3) must be < 16 * numerator
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Choose an output dtype that is invalid for this mode/input dtype
            # combination.  NOTE(review): dtypes outside the cases below
            # (e.g. INT32) would leave incorrect_types unbound, as before.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one element made invalid for the
        requested ERROR_IF; (None, None, None) for unrecognised names."""
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True when output_dtype is illegal for a RESCALE of
        input_dtype (False for unknown input dtypes)."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        # (error_name is compared against the plain strings, which match the
        # ErrorIf attribute values)
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                # deliberately empties the list rather than dropping one entry
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            # shrink every dimension (never below 1) until small enough
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) perturbed to trigger the requested SLICE
        ERROR_IF; the caller's lists are never modified in place."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                # copy rather than append to avoid mutating the caller's lists
                newStart = start + [1]
                newSize = size + [1]
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of output dtypes that are invalid for a CAST of
        input_dtype; asserts on unsupported input dtypes."""
        if input_dtype in [DType.BOOL, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP32]
        elif input_dtype in [DType.FP16, DType.BF16]:
            outputDType = [DType.BOOL, DType.INT48]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType
326
327
328class TosaErrorValidator:
329 @staticmethod
330 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
331 """Check ERROR_IF statements are caught and set the expected result.
332
333 Args:
334 serializer: the serializer to set the expected result in
335 validator_fcns: a sequence of validator functions to verify the result
336 error_name: the name of the ERROR_IF condition to check for
337 kwargs: keyword arguments for the validator functions
338 Returns:
339 True if the result matches the expected result; otherwise False
340 """
341 overall_result = True
342 for val_fcn in validator_fcns:
343 val_result = val_fcn(True, **kwargs)
344 validator_name = val_result["error_name"]
345 error_result = val_result["error_result"]
346 error_reason = val_result["error_reason"]
347
348 # expect an error IFF the error_name and validator_name match
349 expected_result = error_result == (error_name == validator_name)
350 overall_result &= expected_result
351
352 if expected_result and error_result:
353 serializer.setExpectedReturnCode(2, True, desc=error_reason)
354 elif error_result: # and not expected_result
355 print(
356 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
357 f" Expected: {error_name}, Got: {validator_name}"
358 )
359 elif not expected_result: # and not error_result
360 print(
361 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
362 f" Expected: {error_name}"
363 )
364
365 if not expected_result:
366 for k, v in sorted(kwargs.items()):
367 if k != "op":
368 if k.endswith("dtype"):
369 v = valueToName(DType, v)
370 print(f" {k} = {v}")
371
372 return overall_result
373
374 @staticmethod
375 def evWrongInputType(check=False, **kwargs):
376 error_result = False
377
378 # Find the unsupported input data types
379 op = kwargs["op"]
380 input_dtypes = op["types"]
381 allowed_input_dtypes = {
382 t[0] if isinstance(t, list) else t for t in input_dtypes
383 }
384 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
385
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100386 # Turn the wrong dtypes into required list of types
387 if op["op"] in [
388 Op.FULLY_CONNECTED,
389 Op.CONV2D,
390 Op.CONV3D,
391 Op.DEPTHWISE_CONV2D,
392 Op.TRANSPOSE_CONV2D,
393 ]:
394 wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
395
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100396 if op["op"] == Op.CLAMP:
397 wrong_input_dtypes.remove(DType.INT48)
398
399 if check:
400 input_dtype = kwargs["input_dtype"]
401 if input_dtype not in allowed_input_dtypes:
402 error_result = True
403
404 info_dict = {
405 "error_name": ErrorIf.WrongInputType,
406 "error_result": error_result,
407 "error_reason": "Input data type not supported for this operator",
408 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
409 }
410 return info_dict
411
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        # Validator for ErrorIf.WrongOutputType: per operator (and mode for
        # RESIZE), decides whether the supplied output dtype is inconsistent
        # with the supplied input dtype.
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # Valid pairings: NEAREST INT8->INT8 / INT16->INT16,
                # BILINEAR INT8->INT32 / INT16->INT48, and matching float
                # dtypes (FP16->FP16, BF16->BF16, FP32->FP32).
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE legality is shared with the argument generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always produces INT32 indices
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL widens to INT32; float MUL keeps the input dtype
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison operators always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input dtype has a fixed set of legal cast targets
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype
                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype
                        not in [
                            DType.INT8,
                            DType.INT16,
                            DType.INT32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                ):
                    error_result = True

            elif op["op"] == Op.RFFT2D:
                # RFFT2D has multiple outputs; each must match the input dtype
                if not all([ty == input_dtype for ty in output_dtype]):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default rule: output dtype must match the input dtype
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
603
604 @staticmethod
605 def evWrongRank(check=False, **kwargs):
606 all_ranks = (1, 2, 3, 4, 5)
607
608 # Make a list of incorrect ranks
609 assert "op" in kwargs
610 op = kwargs["op"]
611 rmin, rmax = op["rank"]
612 rank_range = range(rmin, rmax + 1)
613 incorrect_ranks = list(set(all_ranks) - set(rank_range))
614 # Remove small incorrect ranks to avoid index errors
615 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
616 # Set minimum incorrect rank to 3 to avoid index error
617 if op["op"] in [Op.RESIZE]:
618 incorrect_ranks = [3, 5]
619 elif op["op"] in [Op.TRANSPOSE]:
620 incorrect_ranks = [7, 8]
621 elif op["op"] in [Op.CONV3D]:
622 incorrect_ranks = [6, 7]
623
624 error_name = ErrorIf.WrongRank
625 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
626 error_result = False
627 error_reason = "Rank not supported for this operator"
628
629 if check:
630 input_shape = kwargs["input_shape"]
631
632 if (
633 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
634 and len(input_shape) != 4
635 ):
636 error_result = True
637 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
638 error_result = True
639 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
640 error_result = True
641 else:
642 if len(input_shape) not in rank_range:
643 error_result = True
644
645 info_dict = {
646 "error_name": error_name,
647 "error_result": error_result,
648 "error_reason": error_reason,
649 "param_reqs": param_reqs,
650 }
651 return info_dict
652
653 @staticmethod
654 def evWrongInputList(check=False, **kwargs):
655 error_name = ErrorIf.WrongInputList
656 param_reqs = {"rank": None, "dtype": None, "shape": None}
657 error_result = False
658 error_reason = "Op input list does not match expected input"
659
660 if check:
661 op = kwargs["op"]
662 input_list = kwargs["input_list"]
663 num_operands = kwargs["num_operands"]
664 if op["op"] in [Op.SCATTER, Op.GATHER]:
665 # SCATTER/GATHER add an indices input tensor in their build functions
666 num_operands += 1
667 if len(input_list) != num_operands:
668 error_result = True
669
670 info_dict = {
671 "error_name": error_name,
672 "error_result": error_result,
673 "error_reason": error_reason,
674 "param_reqs": param_reqs,
675 }
676 return info_dict
677
678 @staticmethod
679 def evWrongOutputList(check=False, **kwargs):
680 error_name = ErrorIf.WrongOutputList
681 param_reqs = {"rank": None, "dtype": None, "shape": None}
682 error_result = False
683 error_reason = "Op output list does not match expected output"
684
685 if check:
Luke Hutton261b7b62023-01-10 14:50:31 +0000686 op = kwargs["op"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100687 output_list = kwargs["output_list"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000688 expected_length = 1
689 if op["op"] == Op.RFFT2D:
690 expected_length = 2
691
692 if len(output_list) != expected_length:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100693 error_result = True
694
695 info_dict = {
696 "error_name": error_name,
697 "error_result": error_result,
698 "error_reason": error_reason,
699 "param_reqs": param_reqs,
700 }
701 return info_dict
702
703 @staticmethod
704 def evMaxDimExceeded(check=False, **kwargs):
705 error_name = ErrorIf.MaxDimExceeded
706 param_reqs = {
707 "rank": [4, 4],
708 "dtype": [DType.INT8],
709 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
710 }
711 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100712 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100713
714 if check:
715 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100716 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100717 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100718 (input_shape[1] >= MAX_RESIZE_DIMENSION)
719 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
720 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
721 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100722 ):
723 error_result = True
724
725 info_dict = {
726 "error_name": error_name,
727 "error_result": error_result,
728 "error_reason": error_reason,
729 "param_reqs": param_reqs,
730 }
731 return info_dict
732
733 @staticmethod
734 def evBatchMismatch(check=False, **kwargs):
735 error_name = ErrorIf.BatchMismatch
Luke Hutton261b7b62023-01-10 14:50:31 +0000736 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100737 error_result = False
738 error_reason = "Input batch size not equal to output batch size"
739
740 assert "op" in kwargs
741 op = kwargs["op"]
742 rmin, rmax = op["rank"]
743 rank_range = range(rmin, rmax + 1)
744
745 if check:
746 input_shape = kwargs["input_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100747
Luke Hutton261b7b62023-01-10 14:50:31 +0000748 for output in kwargs["result_tensors"]:
749 output_shape = (
750 output.shape
751 ) # Note batch is expected to be the first dim
752 if (len(input_shape) in rank_range) and (
753 input_shape[0] != output_shape[0]
754 ):
755 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100756
757 info_dict = {
758 "error_name": error_name,
759 "error_result": error_result,
760 "error_reason": error_reason,
761 "param_reqs": param_reqs,
762 }
763 return info_dict
764
765 @staticmethod
766 def evChannelMismatch(check=False, **kwargs):
767 error_name = ErrorIf.ChannelMismatch
768 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
769 error_result = False
770 error_reason = "Input channel size not equal to output channel size"
771
772 assert "op" in kwargs
773 op = kwargs["op"]
774 rmin, rmax = op["rank"]
775 rank_range = range(rmin, rmax + 1)
776
777 if check:
778 input_shape = kwargs["input_shape"]
Luke Hutton261b7b62023-01-10 14:50:31 +0000779 for output in kwargs["result_tensors"]:
780 output_shape = output.shape # Note this is just (N, OH, OW, C)
781 if (len(input_shape) in rank_range) and (
782 input_shape[3] != output_shape[3]
783 ):
784 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100785
786 info_dict = {
787 "error_name": error_name,
788 "error_result": error_result,
789 "error_reason": error_reason,
790 "param_reqs": param_reqs,
791 }
792 return info_dict
793
794 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100795 def evScaleSmallerEqualZero(check=False, **kwargs):
796 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100797 param_reqs = {"rank": None, "dtype": None, "shape": None}
798 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100799 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100800
801 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100802 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100803
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100804 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100805 error_result = True
806
807 info_dict = {
808 "error_name": error_name,
809 "error_result": error_result,
810 "error_reason": error_reason,
811 "param_reqs": param_reqs,
812 }
813 return info_dict
814
815 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100816 def evScaleNLargerMax(check=False, **kwargs):
817 error_name = ErrorIf.ScaleNLargerMax
818 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100819 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100820 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100821
822 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100823 scale = kwargs["scale"]
824
825 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
826 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100827
828 info_dict = {
829 "error_name": error_name,
830 "error_result": error_result,
831 "error_reason": error_reason,
832 "param_reqs": param_reqs,
833 }
834 return info_dict
835
836 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100837 def evScaleDLargerMax(check=False, **kwargs):
838 error_name = ErrorIf.ScaleDLargerMax
839 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100840 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100841 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100842
843 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100844 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100845
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100846 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
847 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100848 ):
849 error_result = True
850
851 info_dict = {
852 "error_name": error_name,
853 "error_result": error_result,
854 "error_reason": error_reason,
855 "param_reqs": param_reqs,
856 }
857 return info_dict
858
859 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100860 def evOffsetSmallerMin(check=False, **kwargs):
861 error_name = ErrorIf.OffsetSmallerMin
862 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100863 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100864 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100865
866 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100867 scale = kwargs["scale"]
868 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100869
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100870 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100871 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100872 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100873 error_result = True
874
875 info_dict = {
876 "error_name": error_name,
877 "error_result": error_result,
878 "error_reason": error_reason,
879 "param_reqs": param_reqs,
880 }
881 return info_dict
882
883 @staticmethod
884 def evOffsetLargerEqualMax(check=False, **kwargs):
885 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100886 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100887 error_result = False
888 error_reason = "Offset value larger than or equal to maximum value"
889
890 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100891 scale = kwargs["scale"]
892 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100893
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100894 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
895 error_result = True
896 elif (
897 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
898 ):
899 error_result = True
900
901 info_dict = {
902 "error_name": error_name,
903 "error_result": error_result,
904 "error_reason": error_reason,
905 "param_reqs": param_reqs,
906 }
907 return info_dict
908
909 @staticmethod
910 def evBorderSmallerMin(check=False, **kwargs):
911 error_name = ErrorIf.BorderSmallerMin
912 param_reqs = {"rank": None, "dtype": None, "shape": None}
913 error_result = False
914 error_reason = "Border value smaller than minimum value"
915
916 if check:
917 scale = kwargs["scale"]
918 border = kwargs["border"]
919
920 if (
921 scale[0] > 0
922 and scale[0] <= (1 << 11)
923 and (border[0] < (-16 * scale[0]))
924 ):
925 error_result = True
926 elif (
927 scale[2] > 0
928 and scale[2] <= (1 << 11)
929 and (border[1] < (-16 * scale[2]))
930 ):
931 error_result = True
932
933 info_dict = {
934 "error_name": error_name,
935 "error_result": error_result,
936 "error_reason": error_reason,
937 "param_reqs": param_reqs,
938 }
939 return info_dict
940
941 @staticmethod
942 def evBorderLargerEqualMax(check=False, **kwargs):
943 error_name = ErrorIf.BorderLargerEqualMax
944 param_reqs = {"rank": None, "dtype": None, "shape": None}
945 error_result = False
946 error_reason = "Border value larger than or equal to maximum value"
947
948 if check:
949 scale = kwargs["scale"]
950 border = kwargs["border"]
951
952 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
953 error_result = True
954 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
955 error_result = True
956
957 info_dict = {
958 "error_name": error_name,
959 "error_result": error_result,
960 "error_reason": error_reason,
961 "param_reqs": param_reqs,
962 }
963 return info_dict
964
965 @staticmethod
966 def checkResizeParams(scale, offset, border):
967 return (
968 min(scale) > 0
969 and max(scale[0], scale[2]) <= (1 << 11)
970 and scale[1] < 16 * scale[0]
971 and scale[3] < 16 * scale[2]
972 and offset[0] >= -scale[0]
973 and offset[1] >= -scale[2]
974 and offset[0] < 16 * scale[0]
975 and offset[1] < 16 * scale[2]
976 and border[0] >= -16 * scale[0]
977 and border[1] >= -16 * scale[2]
978 and border[0] < scale[0]
979 and border[1] < scale[2]
980 )
981
    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        """ERROR_IF check: RESIZE output shape differs from the computed shape.

        When check=True, kwargs must supply: input_shape and output_shape
        (NHWC-style lists — dims [1] and [2] are treated as spatial),
        scale ([y_n, y_d, x_n, x_d]), offset ([y, x]) and border ([y, x]).
        Returns the standard error info dict.
        """
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            # Oversized dimensions are covered by a separate MaxDimExceeded
            # check, so only compute the expected shape for in-range sizes.
            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                # Expected size: ((in - 1) * scale_n - offset + border) // scale_d + 1
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                # Compare only the spatial dims (exclude batch and channels).
                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1023
    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        """ERROR_IF check: RESIZE params do not divide to integer output dims.

        Uses the same formula as evResizeOutputShapeMismatch but tests the
        modulo remainder instead of the quotient. Returns the standard error
        info dict.
        """
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                # Remainder of ((in - 1) * scale_n - offset + border) / scale_d
                # per spatial axis; non-zero means a fractional output size.
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1058
1059 @staticmethod
1060 def evRankMismatch(check=False, **kwargs):
1061 error_name = ErrorIf.RankMismatch
1062 param_reqs = {"rank": None, "dtype": None, "shape": None}
1063 error_result = False
1064 error_reason = "Input Rank does not match output rank"
1065
1066 if check:
1067 input1_shape = kwargs["input1"].shape
1068 input2_shape = kwargs["input2"].shape
1069 # In case of SELECT op
1070 input3_shape = (
1071 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1072 )
Luke Hutton261b7b62023-01-10 14:50:31 +00001073
1074 for output in kwargs["result_tensors"]:
1075 output_shape = output.shape
1076 if (
1077 (len(input1_shape) != len(output_shape))
1078 or (len(input2_shape) != len(output_shape))
1079 or (len(input3_shape) != len(output_shape))
1080 ):
1081 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001082
1083 info_dict = {
1084 "error_name": error_name,
1085 "error_result": error_result,
1086 "error_reason": error_reason,
1087 "param_reqs": param_reqs,
1088 }
1089 return info_dict
1090
1091 @staticmethod
1092 def evDimensionMismatch(check=False, **kwargs):
1093 error_name = ErrorIf.DimensionMismatch
1094 param_reqs = {"rank": None, "dtype": None, "shape": None}
1095 error_result = False
1096 error_reason = "Input Dimensions do not match output"
1097
1098 if check:
1099 input1_shape = kwargs["input1"].shape
1100 input2_shape = kwargs["input2"].shape
1101 # In case of SELECT op
1102 input3_shape = (
1103 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1104 )
Luke Hutton261b7b62023-01-10 14:50:31 +00001105
1106 for output in kwargs["result_tensors"]:
1107 output_shape = output.shape
1108 for i in range(
1109 min(len(input1_shape), len(input2_shape), len(input3_shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001110 ):
Luke Hutton261b7b62023-01-10 14:50:31 +00001111 if (
1112 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
1113 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
1114 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
1115 ):
1116 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001117
1118 info_dict = {
1119 "error_name": error_name,
1120 "error_result": error_result,
1121 "error_reason": error_reason,
1122 "param_reqs": param_reqs,
1123 }
1124 return info_dict
1125
1126 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001127 def _getZeroPoint(qinfo, index):
1128 """Return zero point value from quantization info.
1129
1130 Generally input_zp is index 0, output_zp is index 1
1131 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001132 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001133
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """ERROR_IF check: input zero point non-zero for a non-quantizable dtype.

        For MATMUL both operands' zero points are validated; all other ops
        check only the first input. Returns the standard error info dict.
        """
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        inputDtypes = [
            dtype
            for dtype in op["types"]
            # entries may be a per-operand dtype list (input dtype first)
            # or a single dtype value
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                # MATMUL: the second operand's zp is stored at qinfo index 1.
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict
1172
1173 @staticmethod
1174 def evWeightZeroPointNotZero(check=False, **kwargs):
1175 op = kwargs["op"]
1176
1177 # exclude inputs with INT8 weights
1178 inputDtypes = [
1179 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1180 ]
1181
1182 error_name = ErrorIf.WeightZeroPointNotZero
1183 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1184 error_result = False
1185 error_reason = "Weight DType not INT8 and zero point not 0"
1186
1187 if check:
1188 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001189 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001190 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1191 error_result = True
1192
1193 info_dict = {
1194 "error_name": error_name,
1195 "error_result": error_result,
1196 "error_reason": error_reason,
1197 "param_reqs": param_reqs,
1198 }
1199 return info_dict
1200
1201 @staticmethod
1202 def evOutputZeroPointNotZero(check=False, **kwargs):
1203 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001204 inputDtypes = [
1205 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1206 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001207
1208 error_name = ErrorIf.OutputZeroPointNotZero
1209 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1210 error_result = False
1211 error_reason = "Output DType not INT8 and zero point not 0"
1212
1213 if check:
1214 input_dtype = kwargs["input_dtype"]
1215 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001216 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001217 if op["op"] == Op.AVG_POOL2D:
1218 if input_dtype != DType.INT8 and output_zero_point != 0:
1219 error_result = True
1220 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001221 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1222 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001223 ):
1224 error_result = True
1225
1226 info_dict = {
1227 "error_name": error_name,
1228 "error_result": error_result,
1229 "error_reason": error_reason,
1230 "param_reqs": param_reqs,
1231 }
1232 return info_dict
1233
1234 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001235 def evU16InputZeroPointNotValid(check=False, **kwargs):
1236 error_name = ErrorIf.U16InputZeroPointNotValid
1237 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1238 error_result = False
1239 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1240
1241 if check:
1242 input_dtype = kwargs["input_dtype"]
1243 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1244 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1245 0,
1246 32768,
1247 ]
1248
1249 info_dict = {
1250 "error_name": error_name,
1251 "error_result": error_result,
1252 "error_reason": error_reason,
1253 "param_reqs": param_reqs,
1254 }
1255 return info_dict
1256
1257 @staticmethod
1258 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1259 error_name = ErrorIf.U16OutputZeroPointNotValid
1260 param_reqs = {"rank": None, "dtype": None, "shape": None}
1261 error_result = False
1262 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1263
1264 if check:
1265 output_dtype = kwargs["output_dtype"]
1266 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1267
1268 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1269 0,
1270 32768,
1271 ]
1272
1273 info_dict = {
1274 "error_name": error_name,
1275 "error_result": error_result,
1276 "error_reason": error_reason,
1277 "param_reqs": param_reqs,
1278 }
1279 return info_dict
1280
1281 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001282 def evAxisSmallerZero(check=False, **kwargs):
1283 error_name = ErrorIf.AxisSmallerZero
1284 param_reqs = {"rank": None, "dtype": None, "shape": None}
1285 error_result = False
1286 error_reason = "Axis smaller than zero"
1287
1288 if check:
1289 axis = kwargs["axis"]
1290 if axis < 0:
1291 error_result = True
1292
1293 info_dict = {
1294 "error_name": error_name,
1295 "error_result": error_result,
1296 "error_reason": error_reason,
1297 "param_reqs": param_reqs,
1298 }
1299 return info_dict
1300
1301 @staticmethod
1302 def evAxisLargerRank(check=False, **kwargs):
1303 error_name = ErrorIf.AxisLargerRank
1304 param_reqs = {"rank": None, "dtype": None, "shape": None}
1305 error_result = False
1306 error_reason = "Axis larger than rank"
1307
1308 if check:
1309 axis = kwargs["axis"]
1310 shape = kwargs["input_shape"]
1311 if axis > len(shape):
1312 error_result = True
1313
1314 info_dict = {
1315 "error_name": error_name,
1316 "error_result": error_result,
1317 "error_reason": error_reason,
1318 "param_reqs": param_reqs,
1319 }
1320 return info_dict
1321
1322 @staticmethod
1323 def evShapeOfAxisNotOne(check=False, **kwargs):
1324 error_name = ErrorIf.ShapeOfAxisNotOne
1325 param_reqs = {"rank": None, "dtype": None, "shape": None}
1326 error_result = False
1327 error_reason = "shape[axis] is not equal to 1"
1328
1329 if check:
1330 axis = kwargs["axis"]
1331 shape = kwargs["output_shape"]
1332 if (0 <= axis < len(shape)) and shape[axis] != 1:
1333 error_result = True
1334
1335 info_dict = {
1336 "error_name": error_name,
1337 "error_result": error_result,
1338 "error_reason": error_reason,
1339 "param_reqs": param_reqs,
1340 }
1341 return info_dict
1342
1343 @staticmethod
1344 def evPadSmallerZero(check=False, **kwargs):
1345 error_name = ErrorIf.PadSmallerZero
1346 param_reqs = {"rank": None, "dtype": None, "shape": None}
1347 error_result = False
1348 error_reason = "At least one pad is smaller than zero"
1349
1350 if check:
1351 op = kwargs["op"]
1352 pad = kwargs["pad"]
1353 if op["op"] == Op.PAD:
1354 for padding in pad:
1355 if min(padding) < 0:
1356 error_result = True
1357 else:
1358 if min(pad) < 0:
1359 error_result = True
1360
1361 info_dict = {
1362 "error_name": error_name,
1363 "error_result": error_result,
1364 "error_reason": error_reason,
1365 "param_reqs": param_reqs,
1366 }
1367 return info_dict
1368
    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        """ERROR_IF check: a pad value is too large relative to the kernel.

        TRANSPOSE_CONV2D allows negative out_pad values but only down to
        -(kernel_dim - 1); pooling ops require every pad to stay below the
        kernel size on its axis. Returns the standard error info dict.
        """
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                # spatial kernel dims taken from the weight shape (drop the
                # outermost and channel dims)
                kernel = kwargs["weight_shape"][1:-1]
                # pad is [top, bottom, left, right]; each may be negative but
                # must stay strictly greater than -kernel on its axis
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                # only flag when all pads are positive and the kernel is
                # larger than 1 — other combinations are covered elsewhere
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1408
1409 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001410 def evPadOutputShapeMismatch(check=False, **kwargs):
1411 error_name = ErrorIf.PadOutputShapeMismatch
1412 param_reqs = {"rank": None, "dtype": None, "shape": None}
1413 error_result = False
1414 error_reason = "Pad output shape mismatch for requested padding"
1415
1416 if check:
1417 pad = kwargs["pad"]
1418 input_shape = kwargs["input_shape"]
1419 output_shape = kwargs["output_shape"]
1420 for dim, padding in enumerate(pad):
1421 expected_size = input_shape[dim] + padding[0] + padding[1]
1422 if expected_size != output_shape[dim]:
1423 error_result = True
1424
1425 info_dict = {
1426 "error_name": error_name,
1427 "error_result": error_result,
1428 "error_reason": error_reason,
1429 "param_reqs": param_reqs,
1430 }
1431 return info_dict
1432
1433 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001434 def checkPoolingParams(kernel, stride, pad):
1435 return (
1436 min(kernel) >= 1
1437 and min(stride) >= 1
1438 and min(pad) >= 0
1439 and not (
1440 pad[0] >= kernel[0]
1441 or pad[1] >= kernel[0]
1442 or pad[2] >= kernel[1]
1443 or pad[3] >= kernel[1]
1444 )
1445 )
1446
    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        """ERROR_IF check: pooling output shape differs from the computed shape.

        Expects NHWC-style input/output shapes (dims [1] and [2] spatial),
        kernel/stride as [y, x] and pad as [top, bottom, left, right].
        Returns the standard error info dict.
        """
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            # (y_correct/x_correct stay unbound when a stride is 0, but then
            # params_valid below is False and short-circuits before their use)
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1490
    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        """ERROR_IF check: pooling params do not divide to integer output dims.

        Same formula family as evPoolingOutputShapeMismatch but tests the
        stride remainder instead of the quotient. Returns the standard error
        info dict.
        """
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            # (remainders stay unbound when a stride is 0, but then
            # params_valid below is False and short-circuits before their use)
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1528
1529 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001530 def checkConvParams(op, weight_shape, stride, pad, dilation):
1531 if op == Op.TRANSPOSE_CONV2D:
1532 pad_ok = (
1533 pad[0] > -weight_shape[1]
1534 and pad[1] > -weight_shape[1]
1535 and pad[2] > -weight_shape[2]
1536 and pad[3] > -weight_shape[2]
1537 )
1538 else:
1539 pad_ok = min(pad) >= 0
1540
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001541 return (
1542 # Check kernel sizes
1543 min(weight_shape[1:-1]) >= 1
1544 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001545 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001546 and (dilation is None or min(dilation) >= 1)
1547 )
1548
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """ERROR_IF check: conv output shape differs from the computed shape.

        Handles CONV-style ops, DEPTHWISE_CONV2D (different weight layout,
        hence kernel_offset 0) and TRANSPOSE_CONV2D (additive out_pad
        formula, no dilation). Returns the standard error info dict.
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # transpose_conv2d takes no dilation attribute
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # depthwise weights put the spatial dims first, so no offset
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                # one iteration per spatial axis; pad is laid out as
                # [before, after] pairs per axis
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        # out = (in - 1) * stride + out_pad_before
                        #       + out_pad_after + kernel
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        # out = (in - 1 + pads - (kernel - 1) * dilation)
                        #       // stride + 1
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            # compare only the spatial dims (exclude batch and channels)
            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1610
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        """ERROR_IF check: conv params do not divide to integer output dims.

        Same formula as evConvOutputShapeMismatch's conv branch but tests the
        stride remainder. Note: dilation is read unconditionally here —
        presumably this check is never generated for TRANSPOSE_CONV2D.
        Returns the standard error info dict.
        """
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            # depthwise weights put the spatial dims first, so no offset
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            remainders = []
            if min(stride) > 0:
                # remainder of (in - 1 + pads - (kernel - 1) * dilation)
                # / stride per spatial axis; non-zero means fractional output
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1659
1660 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001661 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1662 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1663 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1664 error_result = False
1665 error_reason = (
1666 "Mismatch between output shape provided and expected output shape"
1667 )
1668
1669 if check:
1670 output_shape = kwargs["output_shape"]
1671 input_shape = kwargs["input_shape"]
1672 axis = kwargs["axis"]
1673
1674 dimension_match = True
1675 axis_shift = 0
1676
1677 # Check that rank is correct before trying to check dimensions
1678 if (len(input_shape) - 1) == len(output_shape):
1679 for i in range(len(input_shape)):
1680 if i == axis:
1681 axis_shift = 1
1682 continue
1683 if input_shape[i] != output_shape[i - axis_shift]:
1684 dimension_match = False
1685
1686 if not dimension_match:
1687 error_result = True
1688
1689 info_dict = {
1690 "error_name": error_name,
1691 "error_result": error_result,
1692 "error_reason": error_reason,
1693 "param_reqs": param_reqs,
1694 }
1695 return info_dict
1696
1697 @staticmethod
1698 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1699 error_name = ErrorIf.ArgmaxOutputRankMismatch
1700 param_reqs = {"rank": None, "dtype": None, "shape": None}
1701 error_result = False
1702 error_reason = (
1703 "Mismatch between output shape provided and expected output shape"
1704 )
1705
1706 if check:
1707 output_shape = kwargs["output_shape"]
1708 input_shape = kwargs["input_shape"]
1709 axis = kwargs["axis"]
1710 valid_params = axis >= 0 and axis < len(input_shape)
1711
1712 if valid_params and (len(input_shape) - 1) != len(output_shape):
1713 error_result = True
1714
1715 info_dict = {
1716 "error_name": error_name,
1717 "error_result": error_result,
1718 "error_reason": error_reason,
1719 "param_reqs": param_reqs,
1720 }
1721 return info_dict
1722
1723 @staticmethod
1724 def evKernelSmallerOne(check=False, **kwargs):
1725 error_name = ErrorIf.KernelSmallerOne
1726 param_reqs = {"rank": None, "dtype": None, "shape": None}
1727 error_result = False
1728 error_reason = "At least one kernel dimension is smaller than zero"
1729
1730 if check:
1731 kernel = kwargs["kernel"]
1732 if min(kernel) < 1:
1733 error_result = True
1734
1735 info_dict = {
1736 "error_name": error_name,
1737 "error_result": error_result,
1738 "error_reason": error_reason,
1739 "param_reqs": param_reqs,
1740 }
1741 return info_dict
1742
1743 @staticmethod
1744 def evStrideSmallerOne(check=False, **kwargs):
1745 error_name = ErrorIf.StrideSmallerOne
1746 param_reqs = {"rank": None, "dtype": None, "shape": None}
1747 error_result = False
1748 error_reason = "At least one stride dimension is smaller than zero"
1749
1750 if check:
1751 stride = kwargs["stride"]
1752 if min(stride) < 1:
1753 error_result = True
1754
1755 info_dict = {
1756 "error_name": error_name,
1757 "error_result": error_result,
1758 "error_reason": error_reason,
1759 "param_reqs": param_reqs,
1760 }
1761 return info_dict
1762
1763 @staticmethod
1764 def evDilationSmallerOne(check=False, **kwargs):
1765 error_result = check and min(kwargs["dilation"]) < 1
1766 return {
1767 "error_name": ErrorIf.DilationSmallerOne,
1768 "error_reason": "At least one dilation is smaller than one",
1769 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1770 "error_result": error_result,
1771 }
1772
1773 @staticmethod
1774 def evScaleTrue(check=False, **kwargs):
1775 error_name = ErrorIf.ScaleTrue
1776 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1777 error_result = False
1778 error_reason = "Scale set to true but input type is INT48"
1779
1780 if check:
1781 input_dtype = kwargs["input_dtype"]
1782 scale32 = kwargs["scale32"]
1783 if scale32 and input_dtype == DType.INT48:
1784 error_result = True
1785
1786 info_dict = {
1787 "error_name": error_name,
1788 "error_result": error_result,
1789 "error_reason": error_reason,
1790 "param_reqs": param_reqs,
1791 }
1792 return info_dict
1793
1794 @staticmethod
1795 def evScaleNotTrue(check=False, **kwargs):
1796 error_name = ErrorIf.ScaleNotTrue
1797 param_reqs = {"rank": None, "dtype": None, "shape": None}
1798 error_result = False
1799 error_reason = "Scale set to false but double round set to true"
1800
1801 if check:
1802 scale32 = kwargs["scale32"]
1803 double_round = kwargs["double_round"]
1804 if not scale32 and double_round:
1805 error_result = True
1806
1807 info_dict = {
1808 "error_name": error_name,
1809 "error_result": error_result,
1810 "error_reason": error_reason,
1811 "param_reqs": param_reqs,
1812 }
1813 return info_dict
1814
1815 @staticmethod
1816 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1817 error_name = ErrorIf.TensorSizeInputOutputMismatch
1818 param_reqs = {"rank": None, "dtype": None, "shape": None}
1819 error_result = False
1820 error_reason = "Input tensor size does not match output tensor size"
1821
1822 if check:
1823 input_shape = kwargs["input_shape"]
1824 output_shape = kwargs["output_shape"]
1825 input_size = np.prod(input_shape)
1826 output_size = np.prod(output_shape)
1827 if input_size != output_size:
1828 error_result = True
1829
1830 info_dict = {
1831 "error_name": error_name,
1832 "error_result": error_result,
1833 "error_reason": error_reason,
1834 "param_reqs": param_reqs,
1835 }
1836 return info_dict
1837
1838 @staticmethod
1839 def evStartSmallerZero(check=False, **kwargs):
1840 error_name = ErrorIf.StartSmallerZero
1841 param_reqs = {"rank": None, "dtype": None, "shape": None}
1842 error_result = False
1843 error_reason = "Starting point smaller than zero"
1844
1845 if check:
1846 input_shape = kwargs["input_shape"]
1847 start = kwargs["start"]
1848 rank = len(input_shape)
1849 if len(start) == rank:
1850 for index in range(rank):
1851 if start[index] < 0:
1852 error_result = True
1853
1854 info_dict = {
1855 "error_name": error_name,
1856 "error_result": error_result,
1857 "error_reason": error_reason,
1858 "param_reqs": param_reqs,
1859 }
1860 return info_dict
1861
1862 @staticmethod
1863 def evSizeSmallerEqualZero(check=False, **kwargs):
1864 error_name = ErrorIf.SizeSmallerEqualZero
1865 param_reqs = {"rank": None, "dtype": None, "shape": None}
1866 error_result = False
1867 error_reason = "Size smaller than or equal to zero"
1868
1869 if check:
1870 input_shape = kwargs["input_shape"]
1871 size = kwargs["size"]
1872 rank = len(input_shape)
1873 if len(size) == rank:
1874 for index in range(rank):
1875 if size[index] <= 0:
1876 error_result = True
1877
1878 info_dict = {
1879 "error_name": error_name,
1880 "error_result": error_result,
1881 "error_reason": error_reason,
1882 "param_reqs": param_reqs,
1883 }
1884 return info_dict
1885
1886 @staticmethod
1887 def evStartSizeOutsideBounds(check=False, **kwargs):
1888 error_name = ErrorIf.StartSizeOutsideBounds
1889 param_reqs = {"rank": None, "dtype": None, "shape": None}
1890 error_result = False
1891 error_reason = "starting point plus size larger than input dimension"
1892
1893 if check:
1894 input_shape = kwargs["input_shape"]
1895 start = kwargs["start"]
1896 size = kwargs["size"]
1897 rank = len(input_shape)
1898 if len(start) == rank and len(size) == rank:
1899 for index in range(rank):
1900 if start[index] + size[index] > input_shape[index]:
1901 error_result = True
1902
1903 info_dict = {
1904 "error_name": error_name,
1905 "error_result": error_result,
1906 "error_reason": error_reason,
1907 "param_reqs": param_reqs,
1908 }
1909 return info_dict
1910
1911 @staticmethod
1912 def evSizeOutputShapeMismatch(check=False, **kwargs):
1913 error_name = ErrorIf.SizeOutputShapeMismatch
1914 param_reqs = {"rank": None, "dtype": None, "shape": None}
1915 error_result = False
1916 error_reason = "Size does not match output dimension"
1917
1918 if check:
1919 input_shape = kwargs["input_shape"]
1920 output_shape = kwargs["output_shape"]
1921 size = kwargs["size"]
1922 rank = len(input_shape)
1923 if len(size) == rank:
1924 for index in range(rank):
1925 if size[index] != output_shape[index]:
1926 error_result = True
1927
1928 info_dict = {
1929 "error_name": error_name,
1930 "error_result": error_result,
1931 "error_reason": error_reason,
1932 "param_reqs": param_reqs,
1933 }
1934 return info_dict
1935
1936 @staticmethod
1937 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1938 error_name = ErrorIf.InputSizeStartLengthMismatch
1939 param_reqs = {"rank": None, "dtype": None, "shape": None}
1940 error_result = False
1941 error_reason = "rank of input not equal to length of start or size"
1942
1943 if check:
1944 input_shape = kwargs["input_shape"]
1945 start = kwargs["start"]
1946 size = kwargs["size"]
1947 rank = len(input_shape)
1948 if rank != len(start) or rank != len(size):
1949 error_result = True
1950
1951 info_dict = {
1952 "error_name": error_name,
1953 "error_result": error_result,
1954 "error_reason": error_reason,
1955 "param_reqs": param_reqs,
1956 }
1957 return info_dict
1958
1959 @staticmethod
1960 def evIndexOutsideBounds(check=False, **kwargs):
1961 error_name = ErrorIf.IndexOutsideBounds
1962 param_reqs = {"rank": None, "dtype": None, "shape": None}
1963 error_result = False
1964 error_reason = "Index outside of allowed bounds"
1965
1966 if check:
1967 input_shape = kwargs["input_shape"]
1968 perms = kwargs["perms"]
1969 rank = len(input_shape)
1970
1971 for index in perms:
1972 if index < 0 or index > rank:
1973 error_result = True
1974
1975 info_dict = {
1976 "error_name": error_name,
1977 "error_result": error_result,
1978 "error_reason": error_reason,
1979 "param_reqs": param_reqs,
1980 }
1981 return info_dict
1982
1983 @staticmethod
1984 def evIndexUsedTwice(check=False, **kwargs):
1985 error_name = ErrorIf.IndexUsedTwice
1986 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1987 error_result = False
1988 error_reason = "Index used multiple times"
1989
1990 if check:
1991 perms = kwargs["perms"]
1992
1993 unique_indices = []
1994 for index in perms:
1995 if index in unique_indices:
1996 error_result = True
1997 else:
1998 unique_indices.append(index)
1999
2000 info_dict = {
2001 "error_name": error_name,
2002 "error_result": error_result,
2003 "error_reason": error_reason,
2004 "param_reqs": param_reqs,
2005 }
2006 return info_dict
2007
2008 @staticmethod
2009 def evMaxSmallerMin(check=False, **kwargs):
2010 error_name = ErrorIf.MaxSmallerMin
2011 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2012 error_result = False
2013 error_reason = "Max value smaller than min value"
2014
2015 if check:
2016 max_val = kwargs["max_val"]
2017 min_val = kwargs["min_val"]
2018 if max_val < min_val:
2019 error_result = True
2020
2021 info_dict = {
2022 "error_name": error_name,
2023 "error_result": error_result,
2024 "error_reason": error_reason,
2025 "param_reqs": param_reqs,
2026 }
2027 return info_dict
2028
2029 @staticmethod
2030 def evConcatInputRankMismatch(check=False, **kwargs):
2031 error_name = ErrorIf.ConcatInputRankMismatch
2032 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2033 error_result = False
2034 error_reason = "Input ranks are not identical"
2035
2036 if check:
2037 inputs = kwargs["inputs"]
2038 input_shape = kwargs["input_shape"]
2039 for input in inputs:
2040 if len(input.shape) != len(input_shape):
2041 error_result = True
2042
2043 info_dict = {
2044 "error_name": error_name,
2045 "error_result": error_result,
2046 "error_reason": error_reason,
2047 "param_reqs": param_reqs,
2048 }
2049 return info_dict
2050
2051 @staticmethod
2052 def evConcatInputDimMismatch(check=False, **kwargs):
2053 error_name = ErrorIf.ConcatInputDimMismatch
2054 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2055 error_result = False
2056 error_reason = "Input dimensions differ on too many axes"
2057
2058 if check:
2059 inputs = kwargs["inputs"]
2060 input_shape = kwargs["input_shape"]
2061 axis = kwargs["axis"]
2062
2063 # Ensure rank is valid before checking dims.
2064 valid_rank = True
2065 for input in inputs:
2066 if len(input.shape) != len(input_shape):
2067 valid_rank = False
2068
2069 if valid_rank:
2070 for input in inputs:
2071 for i, dim in enumerate(input.shape):
2072 if dim != input_shape[i] and axis != i:
2073 error_result = True
2074
2075 info_dict = {
2076 "error_name": error_name,
2077 "error_result": error_result,
2078 "error_reason": error_reason,
2079 "param_reqs": param_reqs,
2080 }
2081 return info_dict
2082
2083 @staticmethod
2084 def evConcatShapeSumMismatch(check=False, **kwargs):
2085 error_name = ErrorIf.ConcatShapeSumMismatch
2086 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2087 error_result = False
2088 error_reason = "Sum of dimensions on axis not equal to output dimension"
2089
2090 if check:
2091 inputs = kwargs["inputs"]
2092 input_shape = kwargs["input_shape"]
2093 output_shape = kwargs["output_shape"]
2094 axis = kwargs["axis"]
2095
2096 # Ensure rank is valid before checking dims.
2097 valid_params = True
2098 for input in inputs:
2099 if len(input.shape) != len(input_shape):
2100 valid_params = False
2101 if axis < 0 or axis > len(input_shape):
2102 valid_params = False
2103
2104 if valid_params:
2105 axis_dim_sum = 0
2106 for input in inputs:
2107 axis_dim_sum += input.shape[axis]
2108
2109 if axis_dim_sum != output_shape[axis]:
2110 error_result = True
2111
2112 info_dict = {
2113 "error_name": error_name,
2114 "error_result": error_result,
2115 "error_reason": error_reason,
2116 "param_reqs": param_reqs,
2117 }
2118 return info_dict
2119
2120 @staticmethod
2121 def evInputListThenGraphMismatch(check=False, **kwargs):
2122 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2123 param_reqs = {"rank": None, "dtype": None, "shape": None}
2124 error_result = False
2125 error_reason = "Input list shape does not match then-graph shape"
2126
2127 if check:
2128 a = kwargs["a"]
2129 b = kwargs["b"]
2130 basicBlocks = kwargs["basicBlocks"]
2131 then_block = basicBlocks[1]
2132 then_inputs = then_block.inputs
2133 then_tens = then_block.tensors
2134 if (a.shape != then_tens[then_inputs[0]].shape) or (
2135 b.shape != then_tens[then_inputs[1]].shape
2136 ):
2137 error_result = True
2138
2139 info_dict = {
2140 "error_name": error_name,
2141 "error_result": error_result,
2142 "error_reason": error_reason,
2143 "param_reqs": param_reqs,
2144 }
2145 return info_dict
2146
2147 @staticmethod
2148 def evInputListElseGraphMismatch(check=False, **kwargs):
2149 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2150 param_reqs = {"rank": None, "dtype": None, "shape": None}
2151 error_result = False
2152 error_reason = "Input list shape does not match else-graph shape"
2153
2154 if check:
2155 a = kwargs["a"]
2156 b = kwargs["b"]
2157 basicBlocks = kwargs["basicBlocks"]
2158 else_block = basicBlocks[2]
2159 else_inputs = else_block.inputs
2160 else_tens = else_block.tensors
2161 if (a.shape != else_tens[else_inputs[0]].shape) or (
2162 b.shape != else_tens[else_inputs[1]].shape
2163 ):
2164 error_result = True
2165
2166 info_dict = {
2167 "error_name": error_name,
2168 "error_result": error_result,
2169 "error_reason": error_reason,
2170 "param_reqs": param_reqs,
2171 }
2172 return info_dict
2173
2174 @staticmethod
2175 def evOutputListThenGraphMismatch(check=False, **kwargs):
2176 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2177 param_reqs = {"rank": None, "dtype": None, "shape": None}
2178 error_result = False
2179 error_reason = "Output list shape does not match then-graph shape"
2180
2181 if check:
2182 basicBlocks = kwargs["basicBlocks"]
2183 cond_block = basicBlocks[0]
2184 cond_outputs = cond_block.outputs
2185 cond_tens = cond_block.tensors
2186 then_block = basicBlocks[1]
2187 then_outputs = then_block.outputs
2188 then_tens = then_block.tensors
2189 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2190 error_result = True
2191
2192 info_dict = {
2193 "error_name": error_name,
2194 "error_result": error_result,
2195 "error_reason": error_reason,
2196 "param_reqs": param_reqs,
2197 }
2198 return info_dict
2199
2200 @staticmethod
2201 def evOutputListElseGraphMismatch(check=False, **kwargs):
2202 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2203 param_reqs = {"rank": None, "dtype": None, "shape": None}
2204 error_result = False
2205 error_reason = "Output list shape does not match else-graph shape"
2206
2207 if check:
2208 basicBlocks = kwargs["basicBlocks"]
2209 cond_block = basicBlocks[0]
2210 cond_outputs = cond_block.outputs
2211 cond_tens = cond_block.tensors
2212 else_block = basicBlocks[2]
2213 else_outputs = else_block.outputs
2214 else_tens = else_block.tensors
2215 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2216 error_result = True
2217
2218 info_dict = {
2219 "error_name": error_name,
2220 "error_result": error_result,
2221 "error_reason": error_reason,
2222 "param_reqs": param_reqs,
2223 }
2224 return info_dict
2225
2226 @staticmethod
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002227 def evCondIfCondNotMatchingBool(check=False, **kwargs):
2228 error_name = ErrorIf.CondIfCondNotMatchingBool
2229 param_reqs = {"rank": None, "dtype": None, "shape": None}
2230 error_result = False
2231 error_reason = "Conditional tensor does not match bool type"
2232
2233 if check:
2234 cond = kwargs["cond"]
2235 if cond.dtype != DType.BOOL:
2236 error_result = True
2237
2238 info_dict = {
2239 "error_name": error_name,
2240 "error_result": error_result,
2241 "error_reason": error_reason,
2242 "param_reqs": param_reqs,
2243 }
2244 return info_dict
2245
2246 @staticmethod
2247 def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
2248 error_name = ErrorIf.CondIfCondShapeNotSizeOne
2249 param_reqs = {"rank": None, "dtype": None, "shape": None}
2250 error_result = False
2251 error_reason = "Conditional tensor is not equal to a size of one"
2252
2253 if check:
2254 cond = kwargs["cond"]
2255 # Size of 1 is equivalent to rank 0
2256 if len(cond.shape) != 0:
2257 error_result = True
2258
2259 info_dict = {
2260 "error_name": error_name,
2261 "error_result": error_result,
2262 "error_reason": error_reason,
2263 "param_reqs": param_reqs,
2264 }
2265 return info_dict
2266
2267 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002268 def evInputListOutputListMismatch(check=False, **kwargs):
2269 error_name = ErrorIf.InputListOutputListMismatch
2270 param_reqs = {"rank": None, "dtype": None, "shape": None}
2271 error_result = False
2272 error_reason = "Input list does not match output list"
2273
2274 if check:
2275 basicBlocks = kwargs["basicBlocks"]
2276 while_block = basicBlocks[0]
2277 while_inputs = while_block.inputs
2278 while_outputs = while_block.outputs
2279 while_tens = while_block.tensors
2280 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2281 error_result = True
2282
2283 info_dict = {
2284 "error_name": error_name,
2285 "error_result": error_result,
2286 "error_reason": error_reason,
2287 "param_reqs": param_reqs,
2288 }
2289 return info_dict
2290
2291 @staticmethod
2292 def evInputListCondGraphMismatch(check=False, **kwargs):
2293 error_name = ErrorIf.InputListCondGraphMismatch
2294 param_reqs = {"rank": None, "dtype": None, "shape": None}
2295 error_result = False
2296 error_reason = "Input list does not match cond graph"
2297
2298 if check:
2299 basicBlocks = kwargs["basicBlocks"]
2300 while_block = basicBlocks[0]
2301 while_inputs = while_block.inputs
2302 while_tens = while_block.tensors
2303 cond_block = basicBlocks[1]
2304 cond_inputs = cond_block.inputs
2305 cond_tens = cond_block.tensors
2306 if (
2307 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2308 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2309 error_result = True
2310
2311 info_dict = {
2312 "error_name": error_name,
2313 "error_result": error_result,
2314 "error_reason": error_reason,
2315 "param_reqs": param_reqs,
2316 }
2317 return info_dict
2318
2319 @staticmethod
2320 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2321 error_name = ErrorIf.InputListBodyGraphInputMismatch
2322 param_reqs = {"rank": None, "dtype": None, "shape": None}
2323 error_result = False
2324 error_reason = "Input list does not match body graph input"
2325
2326 if check:
2327 basicBlocks = kwargs["basicBlocks"]
2328 while_block = basicBlocks[0]
2329 while_inputs = while_block.inputs
2330 while_tens = while_block.tensors
2331 body_block = basicBlocks[2]
2332 body_outputs = body_block.inputs
2333 body_tens = body_block.tensors
2334 if (
2335 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2336 ) or (
2337 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2338 ):
2339 error_result = True
2340
2341 info_dict = {
2342 "error_name": error_name,
2343 "error_result": error_result,
2344 "error_reason": error_reason,
2345 "param_reqs": param_reqs,
2346 }
2347 return info_dict
2348
2349 @staticmethod
2350 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2351 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2352 param_reqs = {"rank": None, "dtype": None, "shape": None}
2353 error_result = False
2354 error_reason = "Input list does not match body graph output"
2355
2356 if check:
2357 basicBlocks = kwargs["basicBlocks"]
2358 while_block = basicBlocks[0]
2359 while_inputs = while_block.inputs
2360 while_tens = while_block.tensors
2361 body_block = basicBlocks[2]
2362 body_outputs = body_block.outputs
2363 body_tens = body_block.tensors
2364 if (
2365 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2366 ) or (
2367 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2368 ):
2369 error_result = True
2370 info_dict = {
2371 "error_name": error_name,
2372 "error_result": error_result,
2373 "error_reason": error_reason,
2374 "param_reqs": param_reqs,
2375 }
2376 return info_dict
2377
2378 @staticmethod
2379 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2380 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2381 param_reqs = {"rank": None, "dtype": None, "shape": None}
2382 error_result = False
2383 error_reason = "Cond graph output is not a match list of booleans"
2384
2385 if check:
2386 basicBlocks = kwargs["basicBlocks"]
2387 cond_block = basicBlocks[1]
2388 cond_outputs = cond_block.outputs
2389 cond_tens = cond_block.tensors
2390 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2391 error_result = True
2392
2393 info_dict = {
2394 "error_name": error_name,
2395 "error_result": error_result,
2396 "error_reason": error_reason,
2397 "param_reqs": param_reqs,
2398 }
2399 return info_dict
2400
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002401 @staticmethod
2402 def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
2403 error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
2404 param_reqs = {"rank": None, "dtype": None, "shape": None}
2405 error_result = False
2406 error_reason = "Cond graph output is not a shape of size one"
2407
2408 if check:
2409 basicBlocks = kwargs["basicBlocks"]
2410 cond_block = basicBlocks[1]
2411 cond_outputs = cond_block.outputs
2412 cond_tens = cond_block.tensors
2413 # Size of 1 is equivalent to rank 0
2414 if len(cond_tens[cond_outputs[0]].shape) != 0:
2415 error_result = True
2416
2417 info_dict = {
2418 "error_name": error_name,
2419 "error_result": error_result,
2420 "error_reason": error_reason,
2421 "param_reqs": param_reqs,
2422 }
2423 return info_dict
2424
Luke Hutton261b7b62023-01-10 14:50:31 +00002425 @staticmethod
2426 def evKernelNotPowerOfTwo(check=False, **kwargs):
2427 error_name = ErrorIf.KernelNotPowerOfTwo
2428 param_reqs = {"rank": None, "dtype": None, "shape": None}
2429 error_result = False
2430 error_reason = "kernel height and/or width not a power of two"
2431
2432 def is_power_of_two(x):
2433 return math.log(x, 2).is_integer()
2434
2435 if check:
2436 shape = kwargs["input_shape"]
2437 if len(shape) == 3:
2438 valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
2439 error_result = not valid_kernel
2440
2441 info_dict = {
2442 "error_name": error_name,
2443 "error_result": error_result,
2444 "error_reason": error_reason,
2445 "param_reqs": param_reqs,
2446 }
2447 return info_dict
2448
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002449
2450class TosaInvalidValidator:
2451 @staticmethod
2452 def ivWrongDataTypeOrModeResize(**kwargs):
2453 input_dtype = kwargs["input_dtype"]
2454 args = kwargs["args"]
2455 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002456 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002457
2458 if mode == ResizeMode.BILINEAR:
2459 # Invalid output data type / Invalid input datatype
2460 return (
2461 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002462 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002463 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002464 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002465 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002466 )
2467 elif mode == ResizeMode.NEAREST:
2468 # Invalid output data type / Invalid input datatype
2469 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002470 input_dtype
2471 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002472 )
2473 else:
2474 # Invalid resize mode
2475 return True
2476
2477 @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when the op's spatial geometry makes the test invalid.

        For pooling and convolution ops, checks that the output height/width
        implied by the input shape, strides, padding (and kernel/dilation
        where applicable) is at least 1.  For transpose_conv2d, checks that
        the requested output shape is achievable under FULL, SAME or VALID
        padding.

        Args (all via kwargs):
            opName: operator name string, e.g. "max_pool2d", "conv2d".
            shapeList: list of input shapes; [0] is the data input and
                [1] is the filter shape for convolution ops.
            args: the op's generated argument list — the layout is defined
                by the argument generator elsewhere in this project.

        Returns:
            True if the geometry is invalid for this op, else False.

        Raises:
            AssertionError: if opName is not a recognized op.
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        # MaxPool2D has no accum_dtype arg
        stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
        strides = args[stride_idx]
        padding = args[pad_idx]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            # NOTE(review): kernel is read from args[2]; for avg_pool2d
            # (stride_idx=1, pad_idx=2) that is the same slot as padding —
            # verify against the argument generator's layout.
            kernel_shape = args[2]
            # Pool output-size formula (NHWC layout assumed — TODO confirm):
            # out = (in + pad_before + pad_after + stride - kernel) // stride
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[2]
            filter_shape = inputShapes[1]
            # Spatial dims taken from filter positions [1:-1]
            # (presumably an OHWI filter layout — verify).
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # The geometry is valid if the requested output matches any of
            # the three conventional padding schemes.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            # NOTE(review): dilations is read from args[2], the same slot as
            # padding (pad_idx == 2 here) — verify against the argument
            # generator's layout.
            dilations = args[2]
            filter_shape = inputShapes[1]
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            # Dilated-convolution output-size formula per spatial dimension.
            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        # NOTE: assert is stripped under `python -O`; an unrecognized op
        # would then fall through and return None.
        assert False, f"Unrecognized Op: {opName}"
2573
2574 @staticmethod
2575 def ivNonPositiveOutputShape(**kwargs):
2576 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002577 output_shape = args[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002578 if output_shape[1] <= 0 or output_shape[2] <= 0:
2579 # Negative output shape
2580 return True
2581 return False