# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import valueToName
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode


class ErrorIf(object):
    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"


class TosaErrorIfArgGen:
    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType
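
    # Layout note (inferred from the validators below, not stated in this
    # file): `scale` packs [scale_y_n, scale_y_d, scale_x_n, scale_x_d],
    # while `offset` and `border` hold per-axis [y, x] values. For example,
    # ErrorIf.ScaleNLargerMax pushes a numerator just past the (1 << 11)
    # limit that evScaleNLargerMax checks for.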

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None
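
    # Usage sketch (hypothetical values): for ErrorIf.KernelSmallerOne the
    # method returns (stride, pad, wrongKernel) where wrongKernel has
    # dimensions <= 0, e.g. (0, -2), so the reference model's
    # kernel-smaller-than-one ERROR_IF path is exercised; error names this
    # method does not handle fall through to (None, None, None).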

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False
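
    # Example: eiRescaleWrongOutputType(DType.INT8, DType.INT48) returns True,
    # since INT48 is not among the RESCALE output types listed above for an
    # INT8 input, while eiRescaleWrongOutputType(DType.INT8, DType.INT32)
    # returns False.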

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        # Mess up input/output tensors for ERROR_IF checks
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape
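
    # Worked example: eiRestrictDimensions([3, 64, 64, 5]) first caps each
    # dimension at 32, giving [3, 32, 32, 5]; its product (15360) is already
    # under max_items, so the shrink loop never runs and that shape is
    # returned as-is.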

    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # assert False so unsupported dtypes fail loudly instead of
            # falling through to an unbound outputDType (the original
            # "assert True" could never fire)
            assert False, f"input_dtype ({input_dtype}) not supported"
        return outputDType


class TosaErrorValidator:
    @staticmethod
    def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
        """Check ERROR_IF statements are caught and set the expected result.

        Args:
            serializer: the serializer to set the expected result in
            validator_fcns: a sequence of validator functions to verify the result
            error_name: the name of the ERROR_IF condition to check for
            kwargs: keyword arguments for the validator functions
        Returns:
            True if the result matches the expected result; otherwise False
        """
        overall_result = True
        for val_fcn in validator_fcns:
            val_result = val_fcn(True, **kwargs)
            validator_name = val_result["error_name"]
            error_result = val_result["error_result"]
            error_reason = val_result["error_reason"]

            # expect an error IFF the error_name and validator_name match
            expected_result = error_result == (error_name == validator_name)
            overall_result &= expected_result

            if expected_result and error_result:
                serializer.setExpectedReturnCode(2, True, desc=error_reason)
            elif error_result:  # and not expected_result
                print(
                    f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}, Got: {validator_name}"
                )
            elif not expected_result:  # and not error_result
                print(
                    f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
                    f" Expected: {error_name}"
                )

            if not expected_result:
                for k, v in sorted(kwargs.items()):
                    if k != "op":
                        if k.endswith("dtype"):
                            v = valueToName(DType, v)
                        print(f"  {k} = {v}")

        return overall_result
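
    # The XNOR above makes expected_result True in exactly two cases: the
    # matching validator reported an error (a correctly caught ERROR_IF), or
    # a non-matching validator reported no error. Any other combination is a
    # missed or unexpected ERROR_IF and is printed for debugging.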

    @staticmethod
    def evWrongInputType(check=False, **kwargs):
        error_result = False

        # Find the unsupported input data types
        op = kwargs["op"]
        input_dtypes = op["types"]
        allowed_input_dtypes = {
            t[0] if isinstance(t, list) else t for t in input_dtypes
        }
        wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))

        # Turn the wrong dtypes into the required list of types
        if op["op"] in [
            Op.FULLY_CONNECTED,
            Op.CONV2D,
            Op.CONV3D,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
        ]:
            wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]

        if op["op"] == Op.CLAMP:
            wrong_input_dtypes.remove(DType.INT48)

        if check:
            input_dtype = kwargs["input_dtype"]
            if input_dtype not in allowed_input_dtypes:
                error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongInputType,
            "error_result": error_result,
            "error_reason": "Input data type not supported for this operator",
            "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict

    @staticmethod
    def evWrongRank(check=False, **kwargs):
        all_ranks = (1, 2, 3, 4, 5)

        # Make a list of incorrect ranks
        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)
        incorrect_ranks = list(set(all_ranks) - set(rank_range))
        # Remove small incorrect ranks to avoid index errors
        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
        # Set minimum incorrect rank to 3 to avoid index error
        if op["op"] in [Op.RESIZE]:
            incorrect_ranks = [3, 5]
        elif op["op"] in [Op.TRANSPOSE]:
            incorrect_ranks = [7, 8]
        elif op["op"] in [Op.CONV3D]:
            incorrect_ranks = [6, 7]

        error_name = ErrorIf.WrongRank
        param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank not supported for this operator"

        if check:
            input_shape = kwargs["input_shape"]

            if (
                op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
                and len(input_shape) != 4
            ):
                error_result = True
            elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
                error_result = True
            elif op["op"] == Op.MATMUL and len(input_shape) != 3:
                error_result = True
            else:
                if len(input_shape) not in rank_range:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongInputList(check=False, **kwargs):
        error_name = ErrorIf.WrongInputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op input list does not match expected input"

        if check:
            op = kwargs["op"]
            input_list = kwargs["input_list"]
            num_operands = kwargs["num_operands"]
            if op["op"] in [Op.SCATTER, Op.GATHER]:
                # SCATTER/GATHER add an indices input tensor in their build functions
                num_operands += 1
            if len(input_list) != num_operands:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evWrongOutputList(check=False, **kwargs):
        error_name = ErrorIf.WrongOutputList
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Op output list does not match expected output"

        if check:
            output_list = kwargs["output_list"]
            # Note this will be incorrect if an operator returns more than one output
            if len(output_list) != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxDimExceeded(check=False, **kwargs):
        error_name = ErrorIf.MaxDimExceeded
        param_reqs = {
            "rank": [4, 4],
            "dtype": [DType.INT8],
            "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
        }
        error_result = False
        error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            if (
                (input_shape[1] >= MAX_RESIZE_DIMENSION)
                or (input_shape[2] >= MAX_RESIZE_DIMENSION)
                or (output_shape[1] >= MAX_RESIZE_DIMENSION)
                or (output_shape[2] >= MAX_RESIZE_DIMENSION)
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
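
    # Note (assumption worth verifying against generator.tosa_utils): at the
    # time of writing MAX_RESIZE_DIMENSION is 16384, which is why the
    # param_reqs shapes above use 16584 and 16499 to trip the height or
    # width check.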

    @staticmethod
    def evBatchMismatch(check=False, **kwargs):
        error_name = ErrorIf.BatchMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input batch size not equal to output batch size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs[
                "result_tensor"
            ].shape  # Note this is just (N, OH, OW, C)

            if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evChannelMismatch(check=False, **kwargs):
        error_name = ErrorIf.ChannelMismatch
        param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input channel size not equal to output channel size"

        assert "op" in kwargs
        op = kwargs["op"]
        rmin, rmax = op["rank"]
        rank_range = range(rmin, rmax + 1)

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs[
                "result_tensor"
            ].shape  # Note this is just (N, OH, OW, C)
            if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.ScaleSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale value smaller than or equal to zero"

        if check:
            scale = kwargs["scale"]

            if min(scale) <= 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleNLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale N value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if scale[0] > (1 << 11) or scale[2] > (1 << 11):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleDLargerMax(check=False, **kwargs):
        error_name = ErrorIf.ScaleDLargerMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale D value larger than maximum value"

        if check:
            scale = kwargs["scale"]

            if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
                scale[2] > 0 and scale[3] >= (16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.OffsetSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOffsetLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.OffsetLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Offset value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            offset = kwargs["offset"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
                error_result = True
            elif (
                scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.BorderSmallerMin
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value smaller than minimum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if (
                scale[0] > 0
                and scale[0] <= (1 << 11)
                and (border[0] < (-16 * scale[0]))
            ):
                error_result = True
            elif (
                scale[2] > 0
                and scale[2] <= (1 << 11)
                and (border[1] < (-16 * scale[2]))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evBorderLargerEqualMax(check=False, **kwargs):
        error_name = ErrorIf.BorderLargerEqualMax
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Border value larger than or equal to maximum value"

        if check:
            scale = kwargs["scale"]
            border = kwargs["border"]

            if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
                error_result = True
            elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkResizeParams(scale, offset, border):
        return (
            min(scale) > 0
            and max(scale[0], scale[2]) <= (1 << 11)
            and scale[1] < 16 * scale[0]
            and scale[3] < 16 * scale[2]
            and offset[0] >= -scale[0]
            and offset[1] >= -scale[2]
            and offset[0] < 16 * scale[0]
            and offset[1] < 16 * scale[2]
            and border[0] >= -16 * scale[0]
            and border[1] >= -16 * scale[2]
            and border[0] < scale[0]
            and border[1] < scale[2]
        )
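
    # In summary, a parameter set is valid when, per axis: 0 < scale_n <= 2048,
    # 0 < scale_d < 16 * scale_n, -scale_n <= offset < 16 * scale_n and
    # -16 * scale_n <= border < scale_n. For example, scale = [2, 1, 2, 1],
    # offset = [0, 0], border = [0, 0] passes every check above.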

    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
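
    # Worked example of the expected-shape formula above: a 1x4x4x1 input
    # with scale = [2, 1, 2, 1], offset = [0, 0], border = [0, 0] gives
    # output_y = output_x = ((4 - 1) * 2 - 0 + 0) // 1 + 1 = 7, so any NHWC
    # output other than [N, 7, 7, C] flags the mismatch.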

    @staticmethod
    def evResizeOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ResizeOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            input_shape = kwargs["input_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if params_valid:
                remainder_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) % scale[1]
                remainder_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) % scale[3]

                if max(remainder_y, remainder_x) > 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.RankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input rank does not match output rank"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )
            output_shape = kwargs["result_tensor"].shape
            if (
                (len(input1_shape) != len(output_shape))
                or (len(input2_shape) != len(output_shape))
                or (len(input3_shape) != len(output_shape))
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDimensionMismatch(check=False, **kwargs):
        error_name = ErrorIf.DimensionMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions do not match output"

        if check:
            input1_shape = kwargs["input1"].shape
            input2_shape = kwargs["input2"].shape
            # In case of SELECT op
            input3_shape = (
                kwargs["input3"].shape if "input3" in kwargs else input2_shape
            )
            output_shape = kwargs["result_tensor"].shape
            for i in range(
                min(len(input1_shape), len(input2_shape), len(input3_shape))
            ):
                if (
                    (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
                    or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
                    or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
                ):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
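
    # The loop above applies the broadcasting rule: each input dimension must
    # either be 1 or equal the corresponding output dimension. E.g. an input
    # of shape [1, 3] against an output of [2, 3] is fine, while [2, 3]
    # against [4, 3] sets error_result.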

    @staticmethod
    def _getZeroPoint(qinfo, index):
        """Return zero point value from quantization info.

        Generally input_zp is index 0, output_zp is index 1
        """
        return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001103
1104 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001105 def evInputZeroPointNotZero(check=False, **kwargs):
1106 op = kwargs["op"]
1107 error_result = False
1108
1109 # Quantizable types
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001110 qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001111
1112 # This does not apply to quantizable types
1113 inputDtypes = [
1114 dtype
1115 for dtype in op["types"]
1116 if (isinstance(dtype, list) and dtype[0] not in qTypes)
1117 or (not isinstance(dtype, list) and dtype not in qTypes)
1118 ]
1119
1120 if check:
1121 input_dtype = kwargs["input_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001122 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001123 if op["op"] == Op.MATMUL:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001124 input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001125 for dtype, zp in (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001126 (kwargs["input_dtype"], input_zero_point),
1127 (kwargs["input2_dtype"], input2_zero_point),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001128 ):
1129 if dtype not in qTypes and zp != 0:
1130 error_result = True
1131 break
1132 else:
1133 error_result = input_dtype not in qTypes and input_zero_point != 0
1134
1135 info_dict = {
1136 "error_name": ErrorIf.InputZeroPointNotZero,
1137 "error_result": error_result,
1138 "error_reason": "Input DType not INT8 and zero point not 0",
1139 "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
1140 }
1141 return info_dict
1142
1143 @staticmethod
1144 def evWeightZeroPointNotZero(check=False, **kwargs):
1145 op = kwargs["op"]
1146
1147 # exclude inputs with INT8 weights
1148 inputDtypes = [
1149 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1150 ]
1151
1152 error_name = ErrorIf.WeightZeroPointNotZero
1153 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1154 error_result = False
1155 error_reason = "Weight DType not INT8 and zero point not 0"
1156
1157 if check:
1158 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001159 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001160 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1161 error_result = True
1162
1163 info_dict = {
1164 "error_name": error_name,
1165 "error_result": error_result,
1166 "error_reason": error_reason,
1167 "param_reqs": param_reqs,
1168 }
1169 return info_dict
1170
1171 @staticmethod
1172 def evOutputZeroPointNotZero(check=False, **kwargs):
1173 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001174 inputDtypes = [
1175 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1176 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001177
1178 error_name = ErrorIf.OutputZeroPointNotZero
1179 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1180 error_result = False
1181 error_reason = "Output DType not INT8 and zero point not 0"
1182
1183 if check:
1184 input_dtype = kwargs["input_dtype"]
1185 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001186 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001187 if op["op"] == Op.AVG_POOL2D:
1188 if input_dtype != DType.INT8 and output_zero_point != 0:
1189 error_result = True
1190 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001191 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1192 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001193 ):
1194 error_result = True
1195
1196 info_dict = {
1197 "error_name": error_name,
1198 "error_result": error_result,
1199 "error_reason": error_reason,
1200 "param_reqs": param_reqs,
1201 }
1202 return info_dict
1203
1204 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001205 def evU16InputZeroPointNotValid(check=False, **kwargs):
1206 error_name = ErrorIf.U16InputZeroPointNotValid
1207 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1208 error_result = False
1209 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1210
1211 if check:
1212 input_dtype = kwargs["input_dtype"]
1213 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1214 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1215 0,
1216 32768,
1217 ]
1218
1219 info_dict = {
1220 "error_name": error_name,
1221 "error_result": error_result,
1222 "error_reason": error_reason,
1223 "param_reqs": param_reqs,
1224 }
1225 return info_dict
1226
1227 @staticmethod
1228 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1229 error_name = ErrorIf.U16OutputZeroPointNotValid
1230 param_reqs = {"rank": None, "dtype": None, "shape": None}
1231 error_result = False
1232 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1233
1234 if check:
1235 output_dtype = kwargs["output_dtype"]
1236 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1237
1238 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1239 0,
1240 32768,
1241 ]
1242
1243 info_dict = {
1244 "error_name": error_name,
1245 "error_result": error_result,
1246 "error_reason": error_reason,
1247 "param_reqs": param_reqs,
1248 }
1249 return info_dict

    @staticmethod
    def evAxisSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.AxisSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis smaller than zero"

        if check:
            axis = kwargs["axis"]
            if axis < 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evAxisLargerRank(check=False, **kwargs):
        error_name = ErrorIf.AxisLargerRank
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Axis larger than rank"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["input_shape"]
            if axis > len(shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evShapeOfAxisNotOne(check=False, **kwargs):
        error_name = ErrorIf.ShapeOfAxisNotOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "shape[axis] is not equal to 1"

        if check:
            axis = kwargs["axis"]
            shape = kwargs["output_shape"]
            if (0 <= axis < len(shape)) and shape[axis] != 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.PadSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is smaller than zero"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            if op["op"] == Op.PAD:
                for padding in pad:
                    if min(padding) < 0:
                        error_result = True
            else:
                if min(pad) < 0:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                kernel = kwargs["weight_shape"][1:-1]
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evPadOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PadOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Pad output shape mismatch for requested padding"

        if check:
            pad = kwargs["pad"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            for dim, padding in enumerate(pad):
                expected_size = input_shape[dim] + padding[0] + padding[1]
                if expected_size != output_shape[dim]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
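
    # Worked example: for input_shape = [2, 3] and pad = [[1, 1], [0, 2]] the
    # expected output is [2 + 1 + 1, 3 + 0 + 2] = [4, 5]; anything else sets
    # error_result above.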

    @staticmethod
    def checkPoolingParams(kernel, stride, pad):
        return (
            min(kernel) >= 1
            and min(stride) >= 1
            and min(pad) >= 0
            and not (
                pad[0] >= kernel[0]
                or pad[1] >= kernel[0]
                or pad[2] >= kernel[1]
                or pad[3] >= kernel[1]
            )
        )
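
    # E.g. kernel = (2, 2), stride = (1, 1), pad = (0, 0, 0, 0) is valid,
    # while pad = (2, 0, 0, 0) is not, because pad[0] >= kernel[0].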

    @staticmethod
    def evPoolingOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            output_shape = kwargs["output_shape"]
            OH, OW = output_shape[1], output_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate correct height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
                x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)

            if params_valid and (OH != y_correct or OW != x_correct):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
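
    # Worked example of the formula above: IH = IW = 8, kernel = (2, 2),
    # stride = (2, 2), pad = (0, 0, 0, 0) gives
    # y_correct = x_correct = ((8 + 0 + 0 - 2) // 2) + 1 = 4, so the provided
    # output must be [N, 4, 4, C].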

    @staticmethod
    def evPoolingOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.PoolingOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            pad = kwargs["pad"]
            pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]

            kernel = kwargs["kernel"]
            kernel_y, kernel_x = kernel[0], kernel[1]

            input_shape = kwargs["input_shape"]
            IH, IW = input_shape[1], input_shape[2]

            stride = kwargs["stride"]
            stride_y, stride_x = stride[0], stride[1]

            # calculate remainder of height, width dimensions
            if stride_x != 0 and stride_y != 0:
                y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
                x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
            if params_valid and (y_remainder != 0 or x_remainder != 0):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def checkConvParams(op, weight_shape, stride, pad, dilation):
        if op == Op.TRANSPOSE_CONV2D:
            pad_ok = (
                pad[0] > -weight_shape[1]
                and pad[1] > -weight_shape[1]
                and pad[2] > -weight_shape[2]
                and pad[3] > -weight_shape[2]
            )
        else:
            pad_ok = min(pad) >= 0

        return (
            # Check kernel sizes
            min(weight_shape[1:-1]) >= 1
            and min(stride) >= 1
            and pad_ok
            and (dilation is None or min(dilation) >= 1)
        )
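
    # Note: TRANSPOSE_CONV2D is the one conv variant here that permits
    # negative padding, as long as each pad stays strictly greater than the
    # negated kernel dimension (e.g. pad[0] = -2 is acceptable for a 3x3
    # kernel, pad[0] = -3 is not).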

    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1580
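    # Editorial worked example of the expected-dimension formula above, for a
    # non-transpose convolution: input height 8, kernel height 3, dilation 2,
    # stride 2 and pad (1, 1) gives
    #   (8 - 1 + 1 + 1 - (3 - 1) * 2) // 2 + 1 == 3,
    # so any declared output height other than 3 raises ConvOutputShapeMismatch.
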
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputShapeMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            dimension_match = True
            axis_shift = 0

            # Check that rank is correct before trying to check dimensions
            if (len(input_shape) - 1) == len(output_shape):
                for i in range(len(input_shape)):
                    if i == axis:
                        axis_shift = 1
                        continue
                    if input_shape[i] != output_shape[i - axis_shift]:
                        dimension_match = False

            if not dimension_match:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

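    # Editorial worked example: ARGMAX over axis 1 of a [2, 3, 4] input must
    # produce a [2, 4] output; axis_shift above skips the reduced axis when
    # comparing the remaining input and output dimensions.
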
    @staticmethod
    def evArgmaxOutputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ArgmaxOutputRankMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            output_shape = kwargs["output_shape"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]
            valid_params = axis >= 0 and axis < len(input_shape)

            if valid_params and (len(input_shape) - 1) != len(output_shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evKernelSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.KernelSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one kernel dimension is smaller than one"

        if check:
            kernel = kwargs["kernel"]
            if min(kernel) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStrideSmallerOne(check=False, **kwargs):
        error_name = ErrorIf.StrideSmallerOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one stride dimension is smaller than one"

        if check:
            stride = kwargs["stride"]
            if min(stride) < 1:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evDilationSmallerOne(check=False, **kwargs):
        error_result = check and min(kwargs["dilation"]) < 1
        return {
            "error_name": ErrorIf.DilationSmallerOne,
            "error_reason": "At least one dilation is smaller than one",
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
            "error_result": error_result,
        }

    @staticmethod
    def evScaleTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleTrue
        param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
        error_result = False
        error_reason = "Scale set to true but input type is INT48"

        if check:
            input_dtype = kwargs["input_dtype"]
            scale32 = kwargs["scale32"]
            if scale32 and input_dtype == DType.INT48:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evScaleNotTrue(check=False, **kwargs):
        error_name = ErrorIf.ScaleNotTrue
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Scale set to false but double round set to true"

        if check:
            scale32 = kwargs["scale32"]
            double_round = kwargs["double_round"]
            if not scale32 and double_round:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evTensorSizeInputOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.TensorSizeInputOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input tensor size does not match output tensor size"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            input_size = np.prod(input_shape)
            output_size = np.prod(output_shape)
            if input_size != output_size:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

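    # Editorial worked example: RESHAPE of a [2, 3, 4] input (24 elements) to
    # [4, 6] keeps np.prod equal on both sides, while [4, 5] (20 elements)
    # flags TensorSizeInputOutputMismatch.
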
    @staticmethod
    def evStartSmallerZero(check=False, **kwargs):
        error_name = ErrorIf.StartSmallerZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point smaller than zero"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            rank = len(input_shape)
            if len(start) == rank:
                for index in range(rank):
                    if start[index] < 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evSizeSmallerEqualZero(check=False, **kwargs):
        error_name = ErrorIf.SizeSmallerEqualZero
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size smaller than or equal to zero"

        if check:
            input_shape = kwargs["input_shape"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(size) == rank:
                for index in range(rank):
                    if size[index] <= 0:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evStartSizeOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.StartSizeOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Starting point plus size larger than input dimension"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(start) == rank and len(size) == rank:
                for index in range(rank):
                    if start[index] + size[index] > input_shape[index]:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

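    # Editorial worked example: SLICE on a [10, 10] input with start = [8, 0]
    # and size = [4, 10] is flagged because 8 + 4 > 10 on the first dimension.
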
    @staticmethod
    def evSizeOutputShapeMismatch(check=False, **kwargs):
        error_name = ErrorIf.SizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Size does not match output dimension"

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            size = kwargs["size"]
            rank = len(input_shape)
            if len(size) == rank:
                for index in range(rank):
                    if size[index] != output_shape[index]:
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputSizeStartLengthMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputSizeStartLengthMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Rank of input not equal to length of start or size"

        if check:
            input_shape = kwargs["input_shape"]
            start = kwargs["start"]
            size = kwargs["size"]
            rank = len(input_shape)
            if rank != len(start) or rank != len(size):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexOutsideBounds(check=False, **kwargs):
        error_name = ErrorIf.IndexOutsideBounds
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index outside of allowed bounds"

        if check:
            input_shape = kwargs["input_shape"]
            perms = kwargs["perms"]
            rank = len(input_shape)

            # valid permutation indices are in the range [0, rank - 1]
            for index in perms:
                if index < 0 or index >= rank:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evIndexUsedTwice(check=False, **kwargs):
        error_name = ErrorIf.IndexUsedTwice
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Index used multiple times"

        if check:
            perms = kwargs["perms"]

            unique_indices = []
            for index in perms:
                if index in unique_indices:
                    error_result = True
                else:
                    unique_indices.append(index)

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evMaxSmallerMin(check=False, **kwargs):
        error_name = ErrorIf.MaxSmallerMin
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Max value smaller than min value"

        if check:
            max_val = kwargs["max_val"]
            min_val = kwargs["min_val"]
            if max_val < min_val:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputRankMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputRankMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input ranks are not identical"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            for tensor in inputs:
                if len(tensor.shape) != len(input_shape):
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatInputDimMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatInputDimMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input dimensions differ on too many axes"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            axis = kwargs["axis"]

            # Ensure rank is valid before checking dims.
            valid_rank = True
            for tensor in inputs:
                if len(tensor.shape) != len(input_shape):
                    valid_rank = False

            if valid_rank:
                for tensor in inputs:
                    for i, dim in enumerate(tensor.shape):
                        if dim != input_shape[i] and axis != i:
                            error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evConcatShapeSumMismatch(check=False, **kwargs):
        error_name = ErrorIf.ConcatShapeSumMismatch
        param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
        error_result = False
        error_reason = "Sum of dimensions on axis not equal to output dimension"

        if check:
            inputs = kwargs["inputs"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            axis = kwargs["axis"]

            # Ensure rank and axis are valid before checking dims.
            valid_params = True
            for tensor in inputs:
                if len(tensor.shape) != len(input_shape):
                    valid_params = False
            if axis < 0 or axis >= len(input_shape):
                valid_params = False

            if valid_params:
                axis_dim_sum = 0
                for tensor in inputs:
                    axis_dim_sum += tensor.shape[axis]

                if axis_dim_sum != output_shape[axis]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

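    # Editorial worked example: CONCAT of shapes [2, 3] and [2, 5] on axis 1
    # must declare an output of [2, 8]; any other axis-1 output dimension
    # flags ConcatShapeSumMismatch.
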
    @staticmethod
    def evInputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match then-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            then_block = basicBlocks[1]
            then_inputs = then_block.inputs
            then_tens = then_block.tensors
            if (a.shape != then_tens[then_inputs[0]].shape) or (
                b.shape != then_tens[then_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfInputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list shape does not match else-graph shape"

        if check:
            a = kwargs["a"]
            b = kwargs["b"]
            basicBlocks = kwargs["basicBlocks"]
            else_block = basicBlocks[2]
            else_inputs = else_block.inputs
            else_tens = else_block.tensors
            if (a.shape != else_tens[else_inputs[0]].shape) or (
                b.shape != else_tens[else_inputs[1]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListThenGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListThenGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match then-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            then_block = basicBlocks[1]
            then_outputs = then_block.outputs
            then_tens = then_block.tensors
            if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evOutputListElseGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.CondIfOutputListElseGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Output list shape does not match else-graph shape"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[0]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            else_block = basicBlocks[2]
            else_outputs = else_block.outputs
            else_tens = else_block.tensors
            if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not match bool type"

        if check:
            cond = kwargs["cond"]
            if cond.dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondIfCondShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Conditional tensor does not have a size of one"

        if check:
            cond = kwargs["cond"]
            # Size of 1 is equivalent to rank 0
            if len(cond.shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListOutputListMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListOutputListMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match output list"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_outputs = while_block.outputs
            while_tens = while_block.tensors
            if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListCondGraphMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListCondGraphMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match cond graph"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            cond_block = basicBlocks[1]
            cond_inputs = cond_block.inputs
            cond_tens = cond_block.tensors
            if (
                while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
            ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphInputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphInputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph input"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_inputs = body_block.inputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
        error_name = ErrorIf.InputListBodyGraphOutputMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Input list does not match body graph output"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            while_block = basicBlocks[0]
            while_inputs = while_block.inputs
            while_tens = while_block.tensors
            body_block = basicBlocks[2]
            body_outputs = body_block.outputs
            body_tens = body_block.tensors
            if (
                while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
            ) or (
                while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
            ):
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputNotMatchingBool
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output does not match bool type"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict

    @staticmethod
    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Cond graph output is not a shape of size one"

        if check:
            basicBlocks = kwargs["basicBlocks"]
            cond_block = basicBlocks[1]
            cond_outputs = cond_block.outputs
            cond_tens = cond_block.tensors
            # Size of 1 is equivalent to rank 0
            if len(cond_tens[cond_outputs[0]].shape) != 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict


class TosaInvalidValidator:
    @staticmethod
    def ivWrongDataTypeOrModeResize(**kwargs):
        input_dtype = kwargs["input_dtype"]
        args = kwargs["args"]
        mode = args[0]
        output_dtype = args[5]

        if mode == ResizeMode.BILINEAR:
            # Invalid output data type / invalid input data type
            return (
                not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
            )
        elif mode == ResizeMode.NEAREST:
            # Invalid output data type / invalid input data type
            return (input_dtype != output_dtype) or (
                input_dtype
                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
            )
        else:
            # Invalid resize mode
            return True

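    # Editorial summary of the legal RESIZE combinations implied above:
    # NEAREST needs matching input/output types drawn from INT8, INT16, FP16,
    # BF16 and FP32, while BILINEAR widens the integer accumulators
    # (INT8 -> INT32, INT16 -> INT48) and keeps floating-point types as-is.
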
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        input_shape = inputShapes[0]

        args = kwargs["args"]

        # MaxPool2D has no accum_dtype arg
        stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
        strides = args[stride_idx]
        padding = args[pad_idx]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            kernel_shape = args[2]
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[2]
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

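        # Editorial worked example for the transpose_conv2d branch above:
        # get_out_size(4, 2, 3, 0, 0) == (4 - 1) * 2 + 3 == 9, so a declared
        # output height of 9 is accepted if any of the FULL/SAME/VALID
        # padding interpretations reproduces both output dimensions.
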
        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"

    @staticmethod
    def ivNonPositiveOutputShape(**kwargs):
        args = kwargs["args"]
        output_shape = args[3]
        if output_shape[1] <= 0 or output_shape[2] <= 0:
            # Non-positive output shape
            return True
        return False
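

# Minimal usage sketch (editorial addition, not part of the generator's
# normal flow): the ev* checkers can be driven directly with the keyword
# arguments they read above.
if __name__ == "__main__":
    result = TosaErrorValidator.evMaxSmallerMin(check=True, max_val=1, min_val=5)
    print(result["error_name"], result["error_result"])  # MaxSmallerMin True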