blob: a85069976c55131c5d4d01cac37cf0bae83cc8dd [file] [log] [blame]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003import numpy as np
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01005from generator.tosa_utils import product
6from generator.tosa_utils import usableDTypes
7from generator.tosa_utils import valueToName
8from tosa.DType import DType
9from tosa.Op import Op
10from tosa.ResizeMode import ResizeMode
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011
Matthew Haddone86fd342021-09-07 16:12:21 +010012
class ErrorIf:
    """Names of the ERROR_IF conditions the test generator can provoke.

    Each value is a string equal to its attribute name; the values are used
    as the ``error_name`` keys reported by the ``TosaErrorValidator`` ``ev*``
    checker functions.  (The redundant explicit ``object`` base was dropped
    for consistency with the other classes in this file.)
    """

    MaxDimExceeded = "MaxDimExceeded"
    ScaleSmallerEqualZero = "ScaleSmallerEqualZero"
    ScaleNLargerMax = "ScaleNLargerMax"
    ScaleDLargerMax = "ScaleDLargerMax"
    OffsetSmallerMin = "OffsetSmallerMin"
    OffsetLargerEqualMax = "OffsetLargerEqualMax"
    BorderSmallerMin = "BorderSmallerMin"
    BorderLargerEqualMax = "BorderLargerEqualMax"
    ResizeOutputShapeMismatch = "ResizeOutputShapeMismatch"
    ResizeOutputShapeNonInteger = "ResizeOutputShapeNonInteger"
    WrongInputType = "WrongInputType"
    WrongOutputType = "WrongOutputType"
    WrongInputList = "WrongInputList"
    WrongOutputList = "WrongOutputList"
    WrongRank = "WrongRank"
    BatchMismatch = "BatchMismatch"
    ChannelMismatch = "ChannelMismatch"
    RankMismatch = "RankMismatch"
    DimensionMismatch = "DimensionMismatch"
    InputZeroPointNotZero = "InputZeroPointNotZero"
    WeightZeroPointNotZero = "WeightZeroPointNotZero"
    OutputZeroPointNotZero = "OutputZeroPointNotZero"
    AxisSmallerZero = "AxisSmallerZero"
    AxisLargerRank = "AxisLargerRank"
    ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
    ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
    ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
    KernelSmallerOne = "KernelSmallerOne"
    StrideSmallerOne = "StrideSmallerOne"
    DilationSmallerOne = "DilationSmallerOne"
    PadSmallerZero = "PadSmallerZero"
    PadLargerEqualKernel = "PadLargerEqualKernel"
    PadOutputShapeMismatch = "PadOutputShapeMismatch"
    PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
    PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
    ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
    ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
    ScaleNotTrue = "ScaleNotTrue"
    ScaleTrue = "ScaleTrue"
    TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
    StartSmallerZero = "StartSmallerZero"
    SizeSmallerEqualZero = "SizeSmallerEqualZero"
    StartSizeOutsideBounds = "StartSizeOutsideBounds"
    SizeOutputShapeMismatch = "SizeOutputShapeMismatch"
    InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
    IndexOutsideBounds = "IndexOutsideBounds"
    IndexUsedTwice = "IndexUsedTwice"
    MaxSmallerMin = "MaxSmallerMin"
    ConcatInputRankMismatch = "ConcatInputRankMismatch"
    ConcatInputDimMismatch = "ConcatInputDimMismatch"
    ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
    CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
    CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
    CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
    CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
    InputListOutputListMismatch = "InputListOutputListMismatch"
    InputListCondGraphMismatch = "InputListCondGraphMismatch"
    InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
    InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
    CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010076
77
class TosaErrorIfArgGen:
    """Helpers that perturb otherwise-valid operator arguments so that a
    specific ERROR_IF condition (see ErrorIf) is triggered by the test."""

    @staticmethod
    def eiResizeErrorIf(
        testGen,
        error_name,
        mode,
        dtype,
        shapeList,
        outputDType,
        scale,
        offset,
        border,
    ):
        """Mutate RESIZE arguments to provoke the requested ERROR_IF.

        scale holds [n, d, n, d] - numerators at even indices, denominators
        at odd indices; offset and border are two-element lists.

        Returns:
            (scale, offset, border, outputDType) with the entries relevant
            to error_name modified.
        """
        if error_name == ErrorIf.ScaleSmallerEqualZero:
            index = testGen.randInt(low=0, high=4)
            scale[index] = testGen.rng.choice([-2, -1, 0])
        elif error_name == ErrorIf.ScaleNLargerMax:
            # Numerators (indices 0 and 2) must not exceed 1 << 11
            index = testGen.rng.choice([0, 2])
            scale[index] = (1 << 11) + testGen.rng.choice([1, 2, 3])
        elif error_name == ErrorIf.ScaleDLargerMax:
            # Denominators (indices 1 and 3) must be < 16 * numerator
            index = testGen.rng.choice([1, 3])
            scale[index] = 16 * scale[index - 1] + testGen.rng.choice([0, 1, 2])

        if error_name == ErrorIf.OffsetLargerEqualMax:
            # Offsets must be < 16 * matching scale numerator
            index = testGen.rng.choice([0, 1])
            offset[index] = 16 * scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.OffsetSmallerMin:
            # Offsets must be >= -(matching scale numerator)
            index = testGen.rng.choice([0, 1])
            offset[index] = -scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.BorderLargerEqualMax:
            # Borders must be < matching scale numerator
            index = testGen.rng.choice([0, 1])
            border[index] = scale[index * 2] + testGen.rng.choice([0, 1, 2])
        elif error_name == ErrorIf.BorderSmallerMin:
            # Borders must be >= -16 * matching scale numerator
            index = testGen.rng.choice([0, 1])
            border[index] = -16 * scale[index * 2] - testGen.rng.choice([1, 2, 3])

        if error_name == ErrorIf.WrongOutputType:
            # Pick an output dtype that is illegal for this mode/input dtype.
            # NOTE(review): a dtype outside the cases below leaves
            # incorrect_types unbound and the choice() call raises - callers
            # are expected to pass only the RESIZE-supported dtypes.
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                )
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FP32,
                    DType.FP16,
                )
            elif dtype == DType.FP16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.BF16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                )
            elif dtype == DType.FP32:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP16,
                )
            outputDType = testGen.rng.choice(a=incorrect_types)

        return scale, offset, border, outputDType

    @staticmethod
    def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
        """Return (stride, pad, kernel) with one element corrupted to hit
        error_name, or (None, None, None) if error_name is not handled here."""
        if (
            error_name == ErrorIf.StrideSmallerOne
            # padding must not exceed the kernel size
            and pad[0] < kernel[0]
            and pad[1] < kernel[0]
            and pad[2] < kernel[1]
            and pad[3] < kernel[1]
        ):
            wrongStride = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return wrongStride, pad, kernel
        elif error_name == ErrorIf.PadSmallerZero:
            wrongPad = (
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
                testGen.rng.choice([-1, -2, -3]),
            )
            return stride, wrongPad, kernel
        elif error_name == ErrorIf.KernelSmallerOne:
            wrongKernel = (
                testGen.rng.choice([0, -1, -2, -3]),
                testGen.rng.choice([0, -1, -2, -3]),
            )
            return stride, pad, wrongKernel
        elif error_name == ErrorIf.PadLargerEqualKernel:
            wrongPad = (
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
                testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
            )
            return stride, wrongPad, kernel
        else:
            return None, None, None

    @staticmethod
    def eiRescaleWrongOutputType(input_dtype, output_dtype):
        """Return True if output_dtype is illegal for a RESCALE of input_dtype."""
        if input_dtype == DType.INT8:
            if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT16:
            if output_dtype not in [
                DType.UINT8,
                DType.INT8,
                DType.UINT16,
                DType.INT16,
                DType.INT32,
            ]:
                return True
        elif input_dtype == DType.INT32:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.INT48:
            if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                return True
        elif input_dtype == DType.UINT8:
            if output_dtype not in [DType.INT8, DType.INT16]:
                return True
        elif input_dtype == DType.UINT16:
            if output_dtype != DType.INT16:
                return True
        return False

    @staticmethod
    def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
        """Mess up input/output tensor lists for ERROR_IF checks.

        Randomly grows or shrinks the relevant list.  The string literals
        below intentionally match the values of ErrorIf.WrongInputList and
        ErrorIf.WrongOutputList.
        """
        if error_name == "WrongInputList":
            add_input = testGen.rng.choice([True, False])
            if add_input:
                input_list.append("eiDummyInput")
            else:
                input_list = input_list[:-1]
        elif error_name == "WrongOutputList":
            add_output = testGen.rng.choice([True, False])
            if add_output:
                output_list.append("eiDummyOutput")
            else:
                output_list = []
        return input_list, output_list

    @staticmethod
    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
        """Restrict the dimensions and overall size of a shape to
        max_dim and max_items.
        """
        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
        while product(new_shape) > max_items:
            new_shape = [max(d - 1, 1) for d in new_shape]
        return new_shape

    # Fix: this was the only helper in the class missing @staticmethod
    @staticmethod
    def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
        """Return (start, size) adjusted to trigger the SLICE error_name;
        unhandled error names return the arguments unchanged."""
        if error_name == ErrorIf.StartSmallerZero:
            newStart = []
            for i in range(len(input_shape)):
                newStart.append(testGen.rng.choice([-3, -2, -1]))
            return newStart, size
        elif error_name == ErrorIf.SizeSmallerEqualZero:
            newSize = []
            for i in range(len(input_shape)):
                newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
            return start, newSize
        elif error_name == ErrorIf.StartSizeOutsideBounds:
            newStart, newSize = [], []
            for i in range(len(input_shape)):
                newStart.append(input_shape[i] - 1)
                newSize.append(testGen.rng.choice([2, 3, 4]))
            return newStart, newSize
        elif error_name == ErrorIf.InputSizeStartLengthMismatch:
            remove = testGen.rng.choice([True, False])
            if remove:
                newStart = start[1:]
                newSize = size[1:]
            else:
                newStart = start
                newStart.append(1)
                newSize = size
                newSize.append(1)
            return newStart, newSize
        else:
            return start, size

    @staticmethod
    def eiCastErrorIf(testGen, input_dtype):
        """Return the list of invalid CAST output dtypes for input_dtype.

        Raises:
            ValueError: if input_dtype is not a castable type.
        """
        if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
        elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
            outputDType = [DType.INT48]
        else:
            # Fix: this was `assert True, ...`, which can never fire, so an
            # unsupported dtype fell through to an unrelated NameError on the
            # return statement below - raise a meaningful error instead.
            raise ValueError(f"input_dtype ({input_dtype}) not supported")
        return outputDType
318
319
320class TosaErrorValidator:
321 @staticmethod
322 def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
323 """Check ERROR_IF statements are caught and set the expected result.
324
325 Args:
326 serializer: the serializer to set the expected result in
327 validator_fcns: a sequence of validator functions to verify the result
328 error_name: the name of the ERROR_IF condition to check for
329 kwargs: keyword arguments for the validator functions
330 Returns:
331 True if the result matches the expected result; otherwise False
332 """
333 overall_result = True
334 for val_fcn in validator_fcns:
335 val_result = val_fcn(True, **kwargs)
336 validator_name = val_result["error_name"]
337 error_result = val_result["error_result"]
338 error_reason = val_result["error_reason"]
339
340 # expect an error IFF the error_name and validator_name match
341 expected_result = error_result == (error_name == validator_name)
342 overall_result &= expected_result
343
344 if expected_result and error_result:
345 serializer.setExpectedReturnCode(2, True, desc=error_reason)
346 elif error_result: # and not expected_result
347 print(
348 f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
349 f" Expected: {error_name}, Got: {validator_name}"
350 )
351 elif not expected_result: # and not error_result
352 print(
353 f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
354 f" Expected: {error_name}"
355 )
356
357 if not expected_result:
358 for k, v in sorted(kwargs.items()):
359 if k != "op":
360 if k.endswith("dtype"):
361 v = valueToName(DType, v)
362 print(f" {k} = {v}")
363
364 return overall_result
365
366 @staticmethod
367 def evWrongInputType(check=False, **kwargs):
368 error_result = False
369
370 # Find the unsupported input data types
371 op = kwargs["op"]
372 input_dtypes = op["types"]
373 allowed_input_dtypes = {
374 t[0] if isinstance(t, list) else t for t in input_dtypes
375 }
376 wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
377
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100378 # Turn the wrong dtypes into required list of types
379 if op["op"] in [
380 Op.FULLY_CONNECTED,
381 Op.CONV2D,
382 Op.CONV3D,
383 Op.DEPTHWISE_CONV2D,
384 Op.TRANSPOSE_CONV2D,
385 ]:
386 wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
387
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100388 if op["op"] == Op.CLAMP:
389 wrong_input_dtypes.remove(DType.INT48)
390
391 if check:
392 input_dtype = kwargs["input_dtype"]
393 if input_dtype not in allowed_input_dtypes:
394 error_result = True
395
396 info_dict = {
397 "error_name": ErrorIf.WrongInputType,
398 "error_result": error_result,
399 "error_reason": "Input data type not supported for this operator",
400 "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
401 }
402 return info_dict
403
    @staticmethod
    def evWrongOutputType(check=False, **kwargs):
        """Validator: flag output dtypes that are illegal for the operator
        given its input dtype (and, for RESIZE, its mode).

        Args:
            check: when True, inspect kwargs and set the error flag
            kwargs: must contain "op"; when check is True also
                "input_dtype", "output_dtype" and, for RESIZE, "mode"
        Returns:
            info_dict describing the WrongOutputType condition result
        """
        error_result = False

        if check:
            input_dtype = kwargs["input_dtype"]
            output_dtype = kwargs["output_dtype"]
            op = kwargs["op"]

            if op["op"] == Op.RESIZE:
                # Integer RESIZE legality depends on the mode; floating-point
                # RESIZE must keep the input dtype
                mode = kwargs["mode"]
                if (
                    (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT8
                    )
                    or (
                        mode == ResizeMode.NEAREST
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT16
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT8
                        and output_dtype != DType.INT32
                    )
                    or (
                        mode == ResizeMode.BILINEAR
                        and input_dtype == DType.INT16
                        and output_dtype != DType.INT48
                    )
                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                    or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.RESCALE:
                # RESCALE rules are shared with the argument generator
                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
                    input_dtype, output_dtype
                )

            elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                if (
                    (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                    or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in (DType.FP16, DType.FP32)
                    )
                    or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                    or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
                ):
                    error_result = True

            elif op["op"] == Op.ARGMAX:
                # ARGMAX always produces INT32 indices
                if (
                    input_dtype
                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                    and output_dtype != DType.INT32
                ):
                    error_result = True

            elif op["op"] == Op.MUL:
                # Integer MUL widens to INT32; float MUL keeps its dtype
                if (
                    input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                    and output_dtype != DType.INT32
                ):
                    error_result = True
                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                    error_result = True
                elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
                    error_result = True
                elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                    error_result = True

            elif op["op"] == Op.TABLE:
                if input_dtype == DType.INT8 and output_dtype != DType.INT8:
                    error_result = True
                elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
                    error_result = True

            elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
                # Comparison ops always produce BOOL
                if output_dtype != DType.BOOL:
                    error_result = True

            elif op["op"] == Op.CAST:
                # Each input dtype has an explicit allow-list of cast targets
                if (
                    (
                        input_dtype == DType.BOOL
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.INT8
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT16,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT16
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT32,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.INT32
                        and output_dtype
                        not in [
                            DType.BOOL,
                            DType.INT8,
                            DType.INT16,
                            DType.FP32,
                            DType.FP16,
                            DType.BF16,
                        ]
                    )
                    or (
                        input_dtype == DType.FP16
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.BF16
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                    or (
                        input_dtype == DType.FP32
                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                    )
                ):
                    error_result = True

            elif op["op"] in {
                Op.CONV2D,
                Op.CONV3D,
                Op.DEPTHWISE_CONV2D,
                Op.TRANSPOSE_CONV2D,
            }:
                if (
                    input_dtype == DType.INT8
                    and output_dtype != DType.INT32
                    or input_dtype == DType.INT16
                    and output_dtype != DType.INT48
                    or input_dtype == DType.FP16
                    and output_dtype not in (DType.FP16, DType.FP32)
                    or input_dtype == DType.BF16
                    and output_dtype != DType.FP32
                    or input_dtype == DType.FP32
                    and output_dtype != DType.FP32
                ):
                    error_result = True
                # invalid input types are ignored, to avoid reporting multiple errors

            else:
                # Default rule: output dtype must match the input dtype
                if output_dtype != input_dtype:
                    error_result = True

        info_dict = {
            "error_name": ErrorIf.WrongOutputType,
            "error_result": error_result,
            "error_reason": (
                "Output data type not supported for this configuration of operator"
            ),
            "param_reqs": {"rank": None, "dtype": None, "shape": None},
        }
        return info_dict
582
583 @staticmethod
584 def evWrongRank(check=False, **kwargs):
585 all_ranks = (1, 2, 3, 4, 5)
586
587 # Make a list of incorrect ranks
588 assert "op" in kwargs
589 op = kwargs["op"]
590 rmin, rmax = op["rank"]
591 rank_range = range(rmin, rmax + 1)
592 incorrect_ranks = list(set(all_ranks) - set(rank_range))
593 # Remove small incorrect ranks to avoid index errors
594 incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
595 # Set minimum incorrect rank to 3 to avoid index error
596 if op["op"] in [Op.RESIZE]:
597 incorrect_ranks = [3, 5]
598 elif op["op"] in [Op.TRANSPOSE]:
599 incorrect_ranks = [7, 8]
600 elif op["op"] in [Op.CONV3D]:
601 incorrect_ranks = [6, 7]
602
603 error_name = ErrorIf.WrongRank
604 param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
605 error_result = False
606 error_reason = "Rank not supported for this operator"
607
608 if check:
609 input_shape = kwargs["input_shape"]
610
611 if (
612 op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
613 and len(input_shape) != 4
614 ):
615 error_result = True
616 elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
617 error_result = True
618 elif op["op"] == Op.MATMUL and len(input_shape) != 3:
619 error_result = True
620 else:
621 if len(input_shape) not in rank_range:
622 error_result = True
623
624 info_dict = {
625 "error_name": error_name,
626 "error_result": error_result,
627 "error_reason": error_reason,
628 "param_reqs": param_reqs,
629 }
630 return info_dict
631
632 @staticmethod
633 def evWrongInputList(check=False, **kwargs):
634 error_name = ErrorIf.WrongInputList
635 param_reqs = {"rank": None, "dtype": None, "shape": None}
636 error_result = False
637 error_reason = "Op input list does not match expected input"
638
639 if check:
640 op = kwargs["op"]
641 input_list = kwargs["input_list"]
642 num_operands = kwargs["num_operands"]
643 if op["op"] in [Op.SCATTER, Op.GATHER]:
644 # SCATTER/GATHER add an indices input tensor in their build functions
645 num_operands += 1
646 if len(input_list) != num_operands:
647 error_result = True
648
649 info_dict = {
650 "error_name": error_name,
651 "error_result": error_result,
652 "error_reason": error_reason,
653 "param_reqs": param_reqs,
654 }
655 return info_dict
656
657 @staticmethod
658 def evWrongOutputList(check=False, **kwargs):
659 error_name = ErrorIf.WrongOutputList
660 param_reqs = {"rank": None, "dtype": None, "shape": None}
661 error_result = False
662 error_reason = "Op output list does not match expected output"
663
664 if check:
665 output_list = kwargs["output_list"]
666 # Note this will be incorrect if an operator returns more than one output
667 if len(output_list) != 1:
668 error_result = True
669
670 info_dict = {
671 "error_name": error_name,
672 "error_result": error_result,
673 "error_reason": error_reason,
674 "param_reqs": param_reqs,
675 }
676 return info_dict
677
678 @staticmethod
679 def evMaxDimExceeded(check=False, **kwargs):
680 error_name = ErrorIf.MaxDimExceeded
681 param_reqs = {
682 "rank": [4, 4],
683 "dtype": [DType.INT8],
684 "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
685 }
686 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100687 error_reason = f"At least one maximum dimension is greater than or equal to {MAX_RESIZE_DIMENSION}"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100688
689 if check:
690 input_shape = kwargs["input_shape"]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100691 output_shape = kwargs["output_shape"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100692 if (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100693 (input_shape[1] >= MAX_RESIZE_DIMENSION)
694 or (input_shape[2] >= MAX_RESIZE_DIMENSION)
695 or (output_shape[1] >= MAX_RESIZE_DIMENSION)
696 or (output_shape[2] >= MAX_RESIZE_DIMENSION)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100697 ):
698 error_result = True
699
700 info_dict = {
701 "error_name": error_name,
702 "error_result": error_result,
703 "error_reason": error_reason,
704 "param_reqs": param_reqs,
705 }
706 return info_dict
707
708 @staticmethod
709 def evBatchMismatch(check=False, **kwargs):
710 error_name = ErrorIf.BatchMismatch
711 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
712 error_result = False
713 error_reason = "Input batch size not equal to output batch size"
714
715 assert "op" in kwargs
716 op = kwargs["op"]
717 rmin, rmax = op["rank"]
718 rank_range = range(rmin, rmax + 1)
719
720 if check:
721 input_shape = kwargs["input_shape"]
722 output_shape = kwargs[
723 "result_tensor"
724 ].shape # Note this is just (N, OH, OW, C)
725
726 if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
727 error_result = True
728
729 info_dict = {
730 "error_name": error_name,
731 "error_result": error_result,
732 "error_reason": error_reason,
733 "param_reqs": param_reqs,
734 }
735 return info_dict
736
737 @staticmethod
738 def evChannelMismatch(check=False, **kwargs):
739 error_name = ErrorIf.ChannelMismatch
740 param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
741 error_result = False
742 error_reason = "Input channel size not equal to output channel size"
743
744 assert "op" in kwargs
745 op = kwargs["op"]
746 rmin, rmax = op["rank"]
747 rank_range = range(rmin, rmax + 1)
748
749 if check:
750 input_shape = kwargs["input_shape"]
751 output_shape = kwargs[
752 "result_tensor"
753 ].shape # Note this is just (N, OH, OW, C)
754 if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
755 error_result = True
756
757 info_dict = {
758 "error_name": error_name,
759 "error_result": error_result,
760 "error_reason": error_reason,
761 "param_reqs": param_reqs,
762 }
763 return info_dict
764
765 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100766 def evScaleSmallerEqualZero(check=False, **kwargs):
767 error_name = ErrorIf.ScaleSmallerEqualZero
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100768 param_reqs = {"rank": None, "dtype": None, "shape": None}
769 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100770 error_reason = "Scale value smaller than or equal zero"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100771
772 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100773 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100774
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100775 if min(scale) <= 0:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100776 error_result = True
777
778 info_dict = {
779 "error_name": error_name,
780 "error_result": error_result,
781 "error_reason": error_reason,
782 "param_reqs": param_reqs,
783 }
784 return info_dict
785
786 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100787 def evScaleNLargerMax(check=False, **kwargs):
788 error_name = ErrorIf.ScaleNLargerMax
789 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100790 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100791 error_reason = "Scale N value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100792
793 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100794 scale = kwargs["scale"]
795
796 if scale[0] > (1 << 11) or scale[2] > (1 << 11):
797 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100798
799 info_dict = {
800 "error_name": error_name,
801 "error_result": error_result,
802 "error_reason": error_reason,
803 "param_reqs": param_reqs,
804 }
805 return info_dict
806
807 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100808 def evScaleDLargerMax(check=False, **kwargs):
809 error_name = ErrorIf.ScaleDLargerMax
810 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100811 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100812 error_reason = "Scale D value larger than maximum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100813
814 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100815 scale = kwargs["scale"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100816
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100817 if (scale[0] > 0 and scale[1] >= (16 * scale[0])) or (
818 scale[2] > 0 and scale[3] >= (16 * scale[2])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100819 ):
820 error_result = True
821
822 info_dict = {
823 "error_name": error_name,
824 "error_result": error_result,
825 "error_reason": error_reason,
826 "param_reqs": param_reqs,
827 }
828 return info_dict
829
830 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100831 def evOffsetSmallerMin(check=False, **kwargs):
832 error_name = ErrorIf.OffsetSmallerMin
833 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100834 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100835 error_reason = "Offset value smaller than minimum value"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100836
837 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100838 scale = kwargs["scale"]
839 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100840
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100841 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] < -scale[0]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100842 error_result = True
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100843 elif scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] < -scale[2]):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100844 error_result = True
845
846 info_dict = {
847 "error_name": error_name,
848 "error_result": error_result,
849 "error_reason": error_reason,
850 "param_reqs": param_reqs,
851 }
852 return info_dict
853
854 @staticmethod
855 def evOffsetLargerEqualMax(check=False, **kwargs):
856 error_name = ErrorIf.OffsetLargerEqualMax
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100857 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100858 error_result = False
859 error_reason = "Offset value larger than or equal to maximum value"
860
861 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100862 scale = kwargs["scale"]
863 offset = kwargs["offset"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100864
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100865 if scale[0] > 0 and scale[0] <= (1 << 11) and (offset[0] >= 16 * scale[0]):
866 error_result = True
867 elif (
868 scale[2] > 0 and scale[2] <= (1 << 11) and (offset[1] >= 16 * scale[2])
869 ):
870 error_result = True
871
872 info_dict = {
873 "error_name": error_name,
874 "error_result": error_result,
875 "error_reason": error_reason,
876 "param_reqs": param_reqs,
877 }
878 return info_dict
879
880 @staticmethod
881 def evBorderSmallerMin(check=False, **kwargs):
882 error_name = ErrorIf.BorderSmallerMin
883 param_reqs = {"rank": None, "dtype": None, "shape": None}
884 error_result = False
885 error_reason = "Border value smaller than minimum value"
886
887 if check:
888 scale = kwargs["scale"]
889 border = kwargs["border"]
890
891 if (
892 scale[0] > 0
893 and scale[0] <= (1 << 11)
894 and (border[0] < (-16 * scale[0]))
895 ):
896 error_result = True
897 elif (
898 scale[2] > 0
899 and scale[2] <= (1 << 11)
900 and (border[1] < (-16 * scale[2]))
901 ):
902 error_result = True
903
904 info_dict = {
905 "error_name": error_name,
906 "error_result": error_result,
907 "error_reason": error_reason,
908 "param_reqs": param_reqs,
909 }
910 return info_dict
911
912 @staticmethod
913 def evBorderLargerEqualMax(check=False, **kwargs):
914 error_name = ErrorIf.BorderLargerEqualMax
915 param_reqs = {"rank": None, "dtype": None, "shape": None}
916 error_result = False
917 error_reason = "Border value larger than or equal to maximum value"
918
919 if check:
920 scale = kwargs["scale"]
921 border = kwargs["border"]
922
923 if scale[0] > 0 and scale[0] <= (1 << 11) and (border[0] >= scale[0]):
924 error_result = True
925 elif scale[2] > 0 and scale[2] <= (1 << 11) and (border[1] >= scale[2]):
926 error_result = True
927
928 info_dict = {
929 "error_name": error_name,
930 "error_result": error_result,
931 "error_reason": error_reason,
932 "param_reqs": param_reqs,
933 }
934 return info_dict
935
936 @staticmethod
937 def checkResizeParams(scale, offset, border):
938 return (
939 min(scale) > 0
940 and max(scale[0], scale[2]) <= (1 << 11)
941 and scale[1] < 16 * scale[0]
942 and scale[3] < 16 * scale[2]
943 and offset[0] >= -scale[0]
944 and offset[1] >= -scale[2]
945 and offset[0] < 16 * scale[0]
946 and offset[1] < 16 * scale[2]
947 and border[0] >= -16 * scale[0]
948 and border[1] >= -16 * scale[2]
949 and border[0] < scale[0]
950 and border[1] < scale[2]
951 )
952
    @staticmethod
    def evResizeOutputShapeMismatch(check=False, **kwargs):
        """Build the ERROR_IF info dict for ResizeOutputShapeMismatch.

        When check is True, recomputes the expected RESIZE output height
        and width from input_shape/scale/offset/border and flags an error
        if they differ from the provided output_shape. The computation is
        only attempted when the parameters themselves are valid and all
        dimensions are below MAX_RESIZE_DIMENSION, so this error cannot
        fire at the same time as a parameter-range error.
        """
        error_name = ErrorIf.ResizeOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            scale = kwargs["scale"]
            offset = kwargs["offset"]
            border = kwargs["border"]

            # Ensure parameters are valid
            params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)

            if (
                params_valid
                and max(output_shape) < MAX_RESIZE_DIMENSION
                and max(input_shape) < MAX_RESIZE_DIMENSION
            ):
                # Expected spatial dims: ((in - 1) * scale_n - offset + border)
                # floor-divided by scale_d, plus one.  Shapes are NHWC, so
                # index 1 is height and index 2 is width.
                output_y = (
                    (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
                ) // scale[1] + 1
                output_x = (
                    (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
                ) // scale[3] + 1

                # Compare only the spatial (H, W) dims of the output shape.
                if [output_y, output_x] != output_shape[1:-1]:
                    error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
994
995 @staticmethod
Jeremy Johnsona0e03f32022-06-13 17:48:09 +0100996 def evResizeOutputShapeNonInteger(check=False, **kwargs):
997 error_name = ErrorIf.ResizeOutputShapeNonInteger
998 param_reqs = {"rank": None, "dtype": None, "shape": None}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100999 error_result = False
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001000 error_reason = "Parameters do not yield exact integer output dimensions"
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001001
1002 if check:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001003 input_shape = kwargs["input_shape"]
1004 scale = kwargs["scale"]
1005 offset = kwargs["offset"]
1006 border = kwargs["border"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001007
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001008 # Ensure parameters are valid
1009 params_valid = TosaErrorValidator.checkResizeParams(scale, offset, border)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001010
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001011 if params_valid:
1012 remainder_y = (
1013 (input_shape[1] - 1) * scale[0] - offset[0] + border[0]
1014 ) % scale[1]
1015 remainder_x = (
1016 (input_shape[2] - 1) * scale[2] - offset[1] + border[1]
1017 ) % scale[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001018
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001019 if max(remainder_y, remainder_x) > 0:
1020 error_result = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021
1022 info_dict = {
1023 "error_name": error_name,
1024 "error_result": error_result,
1025 "error_reason": error_reason,
1026 "param_reqs": param_reqs,
1027 }
1028 return info_dict
1029
1030 @staticmethod
1031 def evRankMismatch(check=False, **kwargs):
1032 error_name = ErrorIf.RankMismatch
1033 param_reqs = {"rank": None, "dtype": None, "shape": None}
1034 error_result = False
1035 error_reason = "Input Rank does not match output rank"
1036
1037 if check:
1038 input1_shape = kwargs["input1"].shape
1039 input2_shape = kwargs["input2"].shape
1040 # In case of SELECT op
1041 input3_shape = (
1042 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1043 )
1044 output_shape = kwargs["result_tensor"].shape
1045 if (
1046 (len(input1_shape) != len(output_shape))
1047 or (len(input2_shape) != len(output_shape))
1048 or (len(input3_shape) != len(output_shape))
1049 ):
1050 error_result = True
1051
1052 info_dict = {
1053 "error_name": error_name,
1054 "error_result": error_result,
1055 "error_reason": error_reason,
1056 "param_reqs": param_reqs,
1057 }
1058 return info_dict
1059
1060 @staticmethod
1061 def evDimensionMismatch(check=False, **kwargs):
1062 error_name = ErrorIf.DimensionMismatch
1063 param_reqs = {"rank": None, "dtype": None, "shape": None}
1064 error_result = False
1065 error_reason = "Input Dimensions do not match output"
1066
1067 if check:
1068 input1_shape = kwargs["input1"].shape
1069 input2_shape = kwargs["input2"].shape
1070 # In case of SELECT op
1071 input3_shape = (
1072 kwargs["input3"].shape if "input3" in kwargs else input2_shape
1073 )
1074 output_shape = kwargs["result_tensor"].shape
1075 for i in range(
1076 min(len(input1_shape), len(input2_shape), len(input3_shape))
1077 ):
1078 if (
1079 (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
1080 or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
1081 or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
1082 ):
1083 error_result = True
1084
1085 info_dict = {
1086 "error_name": error_name,
1087 "error_result": error_result,
1088 "error_reason": error_reason,
1089 "param_reqs": param_reqs,
1090 }
1091 return info_dict
1092
1093 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001094 def _getZeroPoint(qinfo, index):
1095 """Return zero point value from quantization info.
1096
1097 Generally input_zp is index 0, output_zp is index 1
1098 """
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001099 return qinfo[index]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001100
    @staticmethod
    def evInputZeroPointNotZero(check=False, **kwargs):
        """Build the ERROR_IF info dict for InputZeroPointNotZero.

        When check is True, flags an error if an input tensor whose dtype
        is not quantizable (not INT8/UINT8/UINT16) carries a non-zero
        zero point.  MATMUL is special-cased because it has two inputs,
        each with its own zero point in qinfo.
        """
        op = kwargs["op"]
        error_result = False

        # Quantizable types
        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)

        # This does not apply to quantizable types
        # (op["types"] entries may be a single dtype or a per-tensor list;
        # for lists, element 0 is the input dtype)
        inputDtypes = [
            dtype
            for dtype in op["types"]
            if (isinstance(dtype, list) and dtype[0] not in qTypes)
            or (not isinstance(dtype, list) and dtype not in qTypes)
        ]

        if check:
            input_dtype = kwargs["input_dtype"]
            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
            if op["op"] == Op.MATMUL:
                # MATMUL: check both operands' zero points independently.
                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                for dtype, zp in (
                    (kwargs["input_dtype"], input_zero_point),
                    (kwargs["input2_dtype"], input2_zero_point),
                ):
                    if dtype not in qTypes and zp != 0:
                        error_result = True
                        break
            else:
                error_result = input_dtype not in qTypes and input_zero_point != 0

        info_dict = {
            "error_name": ErrorIf.InputZeroPointNotZero,
            "error_result": error_result,
            "error_reason": "Input DType not INT8 and zero point not 0",
            "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
        }
        return info_dict
1139
1140 @staticmethod
1141 def evWeightZeroPointNotZero(check=False, **kwargs):
1142 op = kwargs["op"]
1143
1144 # exclude inputs with INT8 weights
1145 inputDtypes = [
1146 t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
1147 ]
1148
1149 error_name = ErrorIf.WeightZeroPointNotZero
1150 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1151 error_result = False
1152 error_reason = "Weight DType not INT8 and zero point not 0"
1153
1154 if check:
1155 weight_dtype = kwargs["weight_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001156 weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001157 if weight_dtype != DType.INT8 and weight_zero_point != 0:
1158 error_result = True
1159
1160 info_dict = {
1161 "error_name": error_name,
1162 "error_result": error_result,
1163 "error_reason": error_reason,
1164 "param_reqs": param_reqs,
1165 }
1166 return info_dict
1167
1168 @staticmethod
1169 def evOutputZeroPointNotZero(check=False, **kwargs):
1170 op = kwargs["op"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001171 inputDtypes = [
1172 t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
1173 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001174
1175 error_name = ErrorIf.OutputZeroPointNotZero
1176 param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
1177 error_result = False
1178 error_reason = "Output DType not INT8 and zero point not 0"
1179
1180 if check:
1181 input_dtype = kwargs["input_dtype"]
1182 output_dtype = kwargs["output_dtype"]
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001183 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001184 if op["op"] == Op.AVG_POOL2D:
1185 if input_dtype != DType.INT8 and output_zero_point != 0:
1186 error_result = True
1187 elif (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001188 output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
1189 and output_zero_point != 0
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001190 ):
1191 error_result = True
1192
1193 info_dict = {
1194 "error_name": error_name,
1195 "error_result": error_result,
1196 "error_reason": error_reason,
1197 "param_reqs": param_reqs,
1198 }
1199 return info_dict
1200
1201 @staticmethod
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001202 def evU16InputZeroPointNotValid(check=False, **kwargs):
1203 error_name = ErrorIf.U16InputZeroPointNotValid
1204 param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
1205 error_result = False
1206 error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
1207
1208 if check:
1209 input_dtype = kwargs["input_dtype"]
1210 input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
1211 error_result = input_dtype == DType.UINT16 and input_zero_point not in [
1212 0,
1213 32768,
1214 ]
1215
1216 info_dict = {
1217 "error_name": error_name,
1218 "error_result": error_result,
1219 "error_reason": error_reason,
1220 "param_reqs": param_reqs,
1221 }
1222 return info_dict
1223
1224 @staticmethod
1225 def evU16OutputZeroPointNotValid(check=False, **kwargs):
1226 error_name = ErrorIf.U16OutputZeroPointNotValid
1227 param_reqs = {"rank": None, "dtype": None, "shape": None}
1228 error_result = False
1229 error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
1230
1231 if check:
1232 output_dtype = kwargs["output_dtype"]
1233 output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
1234
1235 error_result = output_dtype == DType.UINT16 and output_zero_point not in [
1236 0,
1237 32768,
1238 ]
1239
1240 info_dict = {
1241 "error_name": error_name,
1242 "error_result": error_result,
1243 "error_reason": error_reason,
1244 "param_reqs": param_reqs,
1245 }
1246 return info_dict
1247
1248 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001249 def evAxisSmallerZero(check=False, **kwargs):
1250 error_name = ErrorIf.AxisSmallerZero
1251 param_reqs = {"rank": None, "dtype": None, "shape": None}
1252 error_result = False
1253 error_reason = "Axis smaller than zero"
1254
1255 if check:
1256 axis = kwargs["axis"]
1257 if axis < 0:
1258 error_result = True
1259
1260 info_dict = {
1261 "error_name": error_name,
1262 "error_result": error_result,
1263 "error_reason": error_reason,
1264 "param_reqs": param_reqs,
1265 }
1266 return info_dict
1267
1268 @staticmethod
1269 def evAxisLargerRank(check=False, **kwargs):
1270 error_name = ErrorIf.AxisLargerRank
1271 param_reqs = {"rank": None, "dtype": None, "shape": None}
1272 error_result = False
1273 error_reason = "Axis larger than rank"
1274
1275 if check:
1276 axis = kwargs["axis"]
1277 shape = kwargs["input_shape"]
1278 if axis > len(shape):
1279 error_result = True
1280
1281 info_dict = {
1282 "error_name": error_name,
1283 "error_result": error_result,
1284 "error_reason": error_reason,
1285 "param_reqs": param_reqs,
1286 }
1287 return info_dict
1288
1289 @staticmethod
1290 def evShapeOfAxisNotOne(check=False, **kwargs):
1291 error_name = ErrorIf.ShapeOfAxisNotOne
1292 param_reqs = {"rank": None, "dtype": None, "shape": None}
1293 error_result = False
1294 error_reason = "shape[axis] is not equal to 1"
1295
1296 if check:
1297 axis = kwargs["axis"]
1298 shape = kwargs["output_shape"]
1299 if (0 <= axis < len(shape)) and shape[axis] != 1:
1300 error_result = True
1301
1302 info_dict = {
1303 "error_name": error_name,
1304 "error_result": error_result,
1305 "error_reason": error_reason,
1306 "param_reqs": param_reqs,
1307 }
1308 return info_dict
1309
1310 @staticmethod
1311 def evPadSmallerZero(check=False, **kwargs):
1312 error_name = ErrorIf.PadSmallerZero
1313 param_reqs = {"rank": None, "dtype": None, "shape": None}
1314 error_result = False
1315 error_reason = "At least one pad is smaller than zero"
1316
1317 if check:
1318 op = kwargs["op"]
1319 pad = kwargs["pad"]
1320 if op["op"] == Op.PAD:
1321 for padding in pad:
1322 if min(padding) < 0:
1323 error_result = True
1324 else:
1325 if min(pad) < 0:
1326 error_result = True
1327
1328 info_dict = {
1329 "error_name": error_name,
1330 "error_result": error_result,
1331 "error_reason": error_reason,
1332 "param_reqs": param_reqs,
1333 }
1334 return info_dict
1335
    @staticmethod
    def evPadLargerEqualKernel(check=False, **kwargs):
        """Build the ERROR_IF info dict for PadLargerEqualKernel.

        When check is True:
        - TRANSPOSE_CONV2D: flags an error if any (possibly negative) pad
          reaches -kernel_dim or below, using the weight tensor's spatial
          dims as the kernel.
        - Pooling ops: flags an error if any pad is >= the kernel size on
          its axis; only evaluated when all pads are positive and the
          kernel is larger than 1 on every axis.
        """
        error_name = ErrorIf.PadLargerEqualKernel
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "At least one pad is larger than kernel dimension"

        if check:
            pad = kwargs["pad"]
            op = kwargs["op"]
            if op["op"] == Op.TRANSPOSE_CONV2D:
                # transpose_conv2d
                # Kernel spatial dims come from the weight shape (OHWI),
                # dropping the first and last entries.
                kernel = kwargs["weight_shape"][1:-1]
                # pad is [top, bottom, left, right]; top/bottom compare
                # against kernel height, left/right against kernel width.
                if (
                    pad[0] <= -kernel[0]
                    or pad[1] <= -kernel[0]
                    or pad[2] <= -kernel[1]
                    or pad[3] <= -kernel[1]
                ):
                    error_result = True
            else:
                # pooling op
                kernel = kwargs["kernel"]
                if min(pad) > 0 and min(kernel) > 1:
                    if (
                        pad[0] >= kernel[0]
                        or pad[1] >= kernel[0]
                        or pad[2] >= kernel[1]
                        or pad[3] >= kernel[1]
                    ):
                        error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1375
1376 @staticmethod
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01001377 def evPadOutputShapeMismatch(check=False, **kwargs):
1378 error_name = ErrorIf.PadOutputShapeMismatch
1379 param_reqs = {"rank": None, "dtype": None, "shape": None}
1380 error_result = False
1381 error_reason = "Pad output shape mismatch for requested padding"
1382
1383 if check:
1384 pad = kwargs["pad"]
1385 input_shape = kwargs["input_shape"]
1386 output_shape = kwargs["output_shape"]
1387 for dim, padding in enumerate(pad):
1388 expected_size = input_shape[dim] + padding[0] + padding[1]
1389 if expected_size != output_shape[dim]:
1390 error_result = True
1391
1392 info_dict = {
1393 "error_name": error_name,
1394 "error_result": error_result,
1395 "error_reason": error_reason,
1396 "param_reqs": param_reqs,
1397 }
1398 return info_dict
1399
1400 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001401 def checkPoolingParams(kernel, stride, pad):
1402 return (
1403 min(kernel) >= 1
1404 and min(stride) >= 1
1405 and min(pad) >= 0
1406 and not (
1407 pad[0] >= kernel[0]
1408 or pad[1] >= kernel[0]
1409 or pad[2] >= kernel[1]
1410 or pad[3] >= kernel[1]
1411 )
1412 )
1413
1414 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001415 def evPoolingOutputShapeMismatch(check=False, **kwargs):
1416 error_name = ErrorIf.PoolingOutputShapeMismatch
1417 param_reqs = {"rank": None, "dtype": None, "shape": None}
1418 error_result = False
1419 error_reason = (
1420 "Mismatch between output shape provided and expected output shape"
1421 )
1422
1423 if check:
1424 pad = kwargs["pad"]
1425 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1426
1427 kernel = kwargs["kernel"]
1428 kernel_y, kernel_x = kernel[0], kernel[1]
1429
1430 input_shape = kwargs["input_shape"]
1431 IH, IW = input_shape[1], input_shape[2]
1432
1433 output_shape = kwargs["output_shape"]
1434 OH, OW = output_shape[1], output_shape[2]
1435
1436 stride = kwargs["stride"]
1437 stride_y, stride_x = stride[0], stride[1]
1438
1439 # calculate correct height, width dimensions
1440 if stride_x != 0 and stride_y != 0:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001441 y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
1442 x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001443
1444 # ensure parameters are valid
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001445 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001446
1447 if params_valid and (OH != y_correct or OW != x_correct):
1448 error_result = True
1449
1450 info_dict = {
1451 "error_name": error_name,
1452 "error_result": error_result,
1453 "error_reason": error_reason,
1454 "param_reqs": param_reqs,
1455 }
1456 return info_dict
1457
1458 @staticmethod
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001459 def evPoolingOutputShapeNonInteger(check=False, **kwargs):
1460 error_name = ErrorIf.PoolingOutputShapeNonInteger
1461 param_reqs = {"rank": None, "dtype": None, "shape": None}
1462 error_result = False
1463 error_reason = "Parameters do not yield exact integer output dimensions"
1464
1465 if check:
1466 pad = kwargs["pad"]
1467 pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
1468
1469 kernel = kwargs["kernel"]
1470 kernel_y, kernel_x = kernel[0], kernel[1]
1471
1472 input_shape = kwargs["input_shape"]
1473 IH, IW = input_shape[1], input_shape[2]
1474
1475 stride = kwargs["stride"]
1476 stride_y, stride_x = stride[0], stride[1]
1477
1478 # calculate remainder of height, width dimensions
1479 if stride_x != 0 and stride_y != 0:
1480 y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
1481 x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
1482
1483 # ensure parameters are valid
1484 params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
1485 if params_valid and (y_remainder != 0 or x_remainder != 0):
1486 error_result = True
1487
1488 info_dict = {
1489 "error_name": error_name,
1490 "error_result": error_result,
1491 "error_reason": error_reason,
1492 "param_reqs": param_reqs,
1493 }
1494 return info_dict
1495
1496 @staticmethod
Eric Kunzec1a97832022-07-01 16:56:09 -07001497 def checkConvParams(op, weight_shape, stride, pad, dilation):
1498 if op == Op.TRANSPOSE_CONV2D:
1499 pad_ok = (
1500 pad[0] > -weight_shape[1]
1501 and pad[1] > -weight_shape[1]
1502 and pad[2] > -weight_shape[2]
1503 and pad[3] > -weight_shape[2]
1504 )
1505 else:
1506 pad_ok = min(pad) >= 0
1507
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001508 return (
1509 # Check kernel sizes
1510 min(weight_shape[1:-1]) >= 1
1511 and min(stride) >= 1
Eric Kunzec1a97832022-07-01 16:56:09 -07001512 and pad_ok
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001513 and (dilation is None or min(dilation) >= 1)
1514 )
1515
    @staticmethod
    def evConvOutputShapeMismatch(check=False, **kwargs):
        """Build the ERROR_IF info dict for ConvOutputShapeMismatch.

        When check is True and the conv parameters are valid, recomputes
        the expected spatial output dims and flags an error if they do
        not match the provided output shape (batch and channel dims are
        excluded from the comparison).
        """
        error_name = ErrorIf.ConvOutputShapeMismatch
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = (
            "Mismatch between output shape provided and expected output shape"
        )

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            output_shape = kwargs["output_shape"]
            # TRANSPOSE_CONV2D has no dilation attribute.
            dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights put spatial dims first; other conv
            # ops have an output-channel dim before them.
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct dimensions
            dims_correct = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad holds [before, after] pairs flattened per axis.
                    pad_offset = index * 2
                    if op["op"] == Op.TRANSPOSE_CONV2D:
                        # Transpose conv: output grows with stride and pad.
                        dims_correct.append(
                            (input_shape[index + 1] - 1) * stride[index]
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            + weight_shape[index + kernel_offset]
                        )
                    else:
                        # Forward conv: standard dilated-conv output formula.
                        dims_correct.append(
                            (
                                input_shape[index + 1]
                                - 1
                                + pad[pad_offset]
                                + pad[pad_offset + 1]
                                - (weight_shape[index + kernel_offset] - 1)
                                * dilation[index]
                            )
                            // stride[index]
                            + 1
                        )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )

            if params_valid and output_shape[1:-1] != dims_correct:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1577
    @staticmethod
    def evConvOutputShapeNonInteger(check=False, **kwargs):
        """Build the ERROR_IF info dict for ConvOutputShapeNonInteger.

        When check is True and the conv parameters are valid, flags an
        error if the conv output formula leaves a non-zero remainder on
        any spatial axis (i.e. the dims are not exact integers).
        """
        error_name = ErrorIf.ConvOutputShapeNonInteger
        param_reqs = {"rank": None, "dtype": None, "shape": None}
        error_result = False
        error_reason = "Parameters do not yield exact integer output dimensions"

        if check:
            op = kwargs["op"]
            pad = kwargs["pad"]
            weight_shape = kwargs["weight_shape"]
            input_shape = kwargs["input_shape"]
            dilation = kwargs["dilation"]
            stride = kwargs["stride"]

            # DEPTHWISE_CONV2D weights put spatial dims first; other conv
            # ops have an output-channel dim before them.
            kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1

            # calculate correct height, width dimensions
            remainders = []
            if min(stride) > 0:
                for index in range(len(stride)):
                    # pad holds [before, after] pairs flattened per axis.
                    pad_offset = index * 2
                    remainders.append(
                        (
                            input_shape[index + 1]
                            - 1
                            + pad[pad_offset]
                            + pad[pad_offset + 1]
                            - (weight_shape[index + kernel_offset] - 1)
                            * dilation[index]
                        )
                        % stride[index]
                    )

            # ensure parameters are valid
            params_valid = TosaErrorValidator.checkConvParams(
                op["op"], weight_shape, stride, pad, dilation
            )
            if params_valid and max(remainders) > 0:
                error_result = True

        info_dict = {
            "error_name": error_name,
            "error_result": error_result,
            "error_reason": error_reason,
            "param_reqs": param_reqs,
        }
        return info_dict
1626
1627 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001628 def evArgmaxOutputShapeMismatch(check=False, **kwargs):
1629 error_name = ErrorIf.ArgmaxOutputShapeMismatch
1630 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1631 error_result = False
1632 error_reason = (
1633 "Mismatch between output shape provided and expected output shape"
1634 )
1635
1636 if check:
1637 output_shape = kwargs["output_shape"]
1638 input_shape = kwargs["input_shape"]
1639 axis = kwargs["axis"]
1640
1641 dimension_match = True
1642 axis_shift = 0
1643
1644 # Check that rank is correct before trying to check dimensions
1645 if (len(input_shape) - 1) == len(output_shape):
1646 for i in range(len(input_shape)):
1647 if i == axis:
1648 axis_shift = 1
1649 continue
1650 if input_shape[i] != output_shape[i - axis_shift]:
1651 dimension_match = False
1652
1653 if not dimension_match:
1654 error_result = True
1655
1656 info_dict = {
1657 "error_name": error_name,
1658 "error_result": error_result,
1659 "error_reason": error_reason,
1660 "param_reqs": param_reqs,
1661 }
1662 return info_dict
1663
1664 @staticmethod
1665 def evArgmaxOutputRankMismatch(check=False, **kwargs):
1666 error_name = ErrorIf.ArgmaxOutputRankMismatch
1667 param_reqs = {"rank": None, "dtype": None, "shape": None}
1668 error_result = False
1669 error_reason = (
1670 "Mismatch between output shape provided and expected output shape"
1671 )
1672
1673 if check:
1674 output_shape = kwargs["output_shape"]
1675 input_shape = kwargs["input_shape"]
1676 axis = kwargs["axis"]
1677 valid_params = axis >= 0 and axis < len(input_shape)
1678
1679 if valid_params and (len(input_shape) - 1) != len(output_shape):
1680 error_result = True
1681
1682 info_dict = {
1683 "error_name": error_name,
1684 "error_result": error_result,
1685 "error_reason": error_reason,
1686 "param_reqs": param_reqs,
1687 }
1688 return info_dict
1689
1690 @staticmethod
1691 def evKernelSmallerOne(check=False, **kwargs):
1692 error_name = ErrorIf.KernelSmallerOne
1693 param_reqs = {"rank": None, "dtype": None, "shape": None}
1694 error_result = False
1695 error_reason = "At least one kernel dimension is smaller than zero"
1696
1697 if check:
1698 kernel = kwargs["kernel"]
1699 if min(kernel) < 1:
1700 error_result = True
1701
1702 info_dict = {
1703 "error_name": error_name,
1704 "error_result": error_result,
1705 "error_reason": error_reason,
1706 "param_reqs": param_reqs,
1707 }
1708 return info_dict
1709
1710 @staticmethod
1711 def evStrideSmallerOne(check=False, **kwargs):
1712 error_name = ErrorIf.StrideSmallerOne
1713 param_reqs = {"rank": None, "dtype": None, "shape": None}
1714 error_result = False
1715 error_reason = "At least one stride dimension is smaller than zero"
1716
1717 if check:
1718 stride = kwargs["stride"]
1719 if min(stride) < 1:
1720 error_result = True
1721
1722 info_dict = {
1723 "error_name": error_name,
1724 "error_result": error_result,
1725 "error_reason": error_reason,
1726 "param_reqs": param_reqs,
1727 }
1728 return info_dict
1729
1730 @staticmethod
1731 def evDilationSmallerOne(check=False, **kwargs):
1732 error_result = check and min(kwargs["dilation"]) < 1
1733 return {
1734 "error_name": ErrorIf.DilationSmallerOne,
1735 "error_reason": "At least one dilation is smaller than one",
1736 "param_reqs": {"rank": None, "dtype": None, "shape": None},
1737 "error_result": error_result,
1738 }
1739
1740 @staticmethod
1741 def evScaleTrue(check=False, **kwargs):
1742 error_name = ErrorIf.ScaleTrue
1743 param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
1744 error_result = False
1745 error_reason = "Scale set to true but input type is INT48"
1746
1747 if check:
1748 input_dtype = kwargs["input_dtype"]
1749 scale32 = kwargs["scale32"]
1750 if scale32 and input_dtype == DType.INT48:
1751 error_result = True
1752
1753 info_dict = {
1754 "error_name": error_name,
1755 "error_result": error_result,
1756 "error_reason": error_reason,
1757 "param_reqs": param_reqs,
1758 }
1759 return info_dict
1760
1761 @staticmethod
1762 def evScaleNotTrue(check=False, **kwargs):
1763 error_name = ErrorIf.ScaleNotTrue
1764 param_reqs = {"rank": None, "dtype": None, "shape": None}
1765 error_result = False
1766 error_reason = "Scale set to false but double round set to true"
1767
1768 if check:
1769 scale32 = kwargs["scale32"]
1770 double_round = kwargs["double_round"]
1771 if not scale32 and double_round:
1772 error_result = True
1773
1774 info_dict = {
1775 "error_name": error_name,
1776 "error_result": error_result,
1777 "error_reason": error_reason,
1778 "param_reqs": param_reqs,
1779 }
1780 return info_dict
1781
1782 @staticmethod
1783 def evTensorSizeInputOutputMismatch(check=False, **kwargs):
1784 error_name = ErrorIf.TensorSizeInputOutputMismatch
1785 param_reqs = {"rank": None, "dtype": None, "shape": None}
1786 error_result = False
1787 error_reason = "Input tensor size does not match output tensor size"
1788
1789 if check:
1790 input_shape = kwargs["input_shape"]
1791 output_shape = kwargs["output_shape"]
1792 input_size = np.prod(input_shape)
1793 output_size = np.prod(output_shape)
1794 if input_size != output_size:
1795 error_result = True
1796
1797 info_dict = {
1798 "error_name": error_name,
1799 "error_result": error_result,
1800 "error_reason": error_reason,
1801 "param_reqs": param_reqs,
1802 }
1803 return info_dict
1804
1805 @staticmethod
1806 def evStartSmallerZero(check=False, **kwargs):
1807 error_name = ErrorIf.StartSmallerZero
1808 param_reqs = {"rank": None, "dtype": None, "shape": None}
1809 error_result = False
1810 error_reason = "Starting point smaller than zero"
1811
1812 if check:
1813 input_shape = kwargs["input_shape"]
1814 start = kwargs["start"]
1815 rank = len(input_shape)
1816 if len(start) == rank:
1817 for index in range(rank):
1818 if start[index] < 0:
1819 error_result = True
1820
1821 info_dict = {
1822 "error_name": error_name,
1823 "error_result": error_result,
1824 "error_reason": error_reason,
1825 "param_reqs": param_reqs,
1826 }
1827 return info_dict
1828
1829 @staticmethod
1830 def evSizeSmallerEqualZero(check=False, **kwargs):
1831 error_name = ErrorIf.SizeSmallerEqualZero
1832 param_reqs = {"rank": None, "dtype": None, "shape": None}
1833 error_result = False
1834 error_reason = "Size smaller than or equal to zero"
1835
1836 if check:
1837 input_shape = kwargs["input_shape"]
1838 size = kwargs["size"]
1839 rank = len(input_shape)
1840 if len(size) == rank:
1841 for index in range(rank):
1842 if size[index] <= 0:
1843 error_result = True
1844
1845 info_dict = {
1846 "error_name": error_name,
1847 "error_result": error_result,
1848 "error_reason": error_reason,
1849 "param_reqs": param_reqs,
1850 }
1851 return info_dict
1852
1853 @staticmethod
1854 def evStartSizeOutsideBounds(check=False, **kwargs):
1855 error_name = ErrorIf.StartSizeOutsideBounds
1856 param_reqs = {"rank": None, "dtype": None, "shape": None}
1857 error_result = False
1858 error_reason = "starting point plus size larger than input dimension"
1859
1860 if check:
1861 input_shape = kwargs["input_shape"]
1862 start = kwargs["start"]
1863 size = kwargs["size"]
1864 rank = len(input_shape)
1865 if len(start) == rank and len(size) == rank:
1866 for index in range(rank):
1867 if start[index] + size[index] > input_shape[index]:
1868 error_result = True
1869
1870 info_dict = {
1871 "error_name": error_name,
1872 "error_result": error_result,
1873 "error_reason": error_reason,
1874 "param_reqs": param_reqs,
1875 }
1876 return info_dict
1877
1878 @staticmethod
1879 def evSizeOutputShapeMismatch(check=False, **kwargs):
1880 error_name = ErrorIf.SizeOutputShapeMismatch
1881 param_reqs = {"rank": None, "dtype": None, "shape": None}
1882 error_result = False
1883 error_reason = "Size does not match output dimension"
1884
1885 if check:
1886 input_shape = kwargs["input_shape"]
1887 output_shape = kwargs["output_shape"]
1888 size = kwargs["size"]
1889 rank = len(input_shape)
1890 if len(size) == rank:
1891 for index in range(rank):
1892 if size[index] != output_shape[index]:
1893 error_result = True
1894
1895 info_dict = {
1896 "error_name": error_name,
1897 "error_result": error_result,
1898 "error_reason": error_reason,
1899 "param_reqs": param_reqs,
1900 }
1901 return info_dict
1902
1903 @staticmethod
1904 def evInputSizeStartLengthMismatch(check=False, **kwargs):
1905 error_name = ErrorIf.InputSizeStartLengthMismatch
1906 param_reqs = {"rank": None, "dtype": None, "shape": None}
1907 error_result = False
1908 error_reason = "rank of input not equal to length of start or size"
1909
1910 if check:
1911 input_shape = kwargs["input_shape"]
1912 start = kwargs["start"]
1913 size = kwargs["size"]
1914 rank = len(input_shape)
1915 if rank != len(start) or rank != len(size):
1916 error_result = True
1917
1918 info_dict = {
1919 "error_name": error_name,
1920 "error_result": error_result,
1921 "error_reason": error_reason,
1922 "param_reqs": param_reqs,
1923 }
1924 return info_dict
1925
1926 @staticmethod
1927 def evIndexOutsideBounds(check=False, **kwargs):
1928 error_name = ErrorIf.IndexOutsideBounds
1929 param_reqs = {"rank": None, "dtype": None, "shape": None}
1930 error_result = False
1931 error_reason = "Index outside of allowed bounds"
1932
1933 if check:
1934 input_shape = kwargs["input_shape"]
1935 perms = kwargs["perms"]
1936 rank = len(input_shape)
1937
1938 for index in perms:
1939 if index < 0 or index > rank:
1940 error_result = True
1941
1942 info_dict = {
1943 "error_name": error_name,
1944 "error_result": error_result,
1945 "error_reason": error_reason,
1946 "param_reqs": param_reqs,
1947 }
1948 return info_dict
1949
1950 @staticmethod
1951 def evIndexUsedTwice(check=False, **kwargs):
1952 error_name = ErrorIf.IndexUsedTwice
1953 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1954 error_result = False
1955 error_reason = "Index used multiple times"
1956
1957 if check:
1958 perms = kwargs["perms"]
1959
1960 unique_indices = []
1961 for index in perms:
1962 if index in unique_indices:
1963 error_result = True
1964 else:
1965 unique_indices.append(index)
1966
1967 info_dict = {
1968 "error_name": error_name,
1969 "error_result": error_result,
1970 "error_reason": error_reason,
1971 "param_reqs": param_reqs,
1972 }
1973 return info_dict
1974
1975 @staticmethod
1976 def evMaxSmallerMin(check=False, **kwargs):
1977 error_name = ErrorIf.MaxSmallerMin
1978 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
1979 error_result = False
1980 error_reason = "Max value smaller than min value"
1981
1982 if check:
1983 max_val = kwargs["max_val"]
1984 min_val = kwargs["min_val"]
1985 if max_val < min_val:
1986 error_result = True
1987
1988 info_dict = {
1989 "error_name": error_name,
1990 "error_result": error_result,
1991 "error_reason": error_reason,
1992 "param_reqs": param_reqs,
1993 }
1994 return info_dict
1995
1996 @staticmethod
1997 def evConcatInputRankMismatch(check=False, **kwargs):
1998 error_name = ErrorIf.ConcatInputRankMismatch
1999 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2000 error_result = False
2001 error_reason = "Input ranks are not identical"
2002
2003 if check:
2004 inputs = kwargs["inputs"]
2005 input_shape = kwargs["input_shape"]
2006 for input in inputs:
2007 if len(input.shape) != len(input_shape):
2008 error_result = True
2009
2010 info_dict = {
2011 "error_name": error_name,
2012 "error_result": error_result,
2013 "error_reason": error_reason,
2014 "param_reqs": param_reqs,
2015 }
2016 return info_dict
2017
2018 @staticmethod
2019 def evConcatInputDimMismatch(check=False, **kwargs):
2020 error_name = ErrorIf.ConcatInputDimMismatch
2021 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2022 error_result = False
2023 error_reason = "Input dimensions differ on too many axes"
2024
2025 if check:
2026 inputs = kwargs["inputs"]
2027 input_shape = kwargs["input_shape"]
2028 axis = kwargs["axis"]
2029
2030 # Ensure rank is valid before checking dims.
2031 valid_rank = True
2032 for input in inputs:
2033 if len(input.shape) != len(input_shape):
2034 valid_rank = False
2035
2036 if valid_rank:
2037 for input in inputs:
2038 for i, dim in enumerate(input.shape):
2039 if dim != input_shape[i] and axis != i:
2040 error_result = True
2041
2042 info_dict = {
2043 "error_name": error_name,
2044 "error_result": error_result,
2045 "error_reason": error_reason,
2046 "param_reqs": param_reqs,
2047 }
2048 return info_dict
2049
2050 @staticmethod
2051 def evConcatShapeSumMismatch(check=False, **kwargs):
2052 error_name = ErrorIf.ConcatShapeSumMismatch
2053 param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
2054 error_result = False
2055 error_reason = "Sum of dimensions on axis not equal to output dimension"
2056
2057 if check:
2058 inputs = kwargs["inputs"]
2059 input_shape = kwargs["input_shape"]
2060 output_shape = kwargs["output_shape"]
2061 axis = kwargs["axis"]
2062
2063 # Ensure rank is valid before checking dims.
2064 valid_params = True
2065 for input in inputs:
2066 if len(input.shape) != len(input_shape):
2067 valid_params = False
2068 if axis < 0 or axis > len(input_shape):
2069 valid_params = False
2070
2071 if valid_params:
2072 axis_dim_sum = 0
2073 for input in inputs:
2074 axis_dim_sum += input.shape[axis]
2075
2076 if axis_dim_sum != output_shape[axis]:
2077 error_result = True
2078
2079 info_dict = {
2080 "error_name": error_name,
2081 "error_result": error_result,
2082 "error_reason": error_reason,
2083 "param_reqs": param_reqs,
2084 }
2085 return info_dict
2086
2087 @staticmethod
2088 def evInputListThenGraphMismatch(check=False, **kwargs):
2089 error_name = ErrorIf.CondIfInputListThenGraphMismatch
2090 param_reqs = {"rank": None, "dtype": None, "shape": None}
2091 error_result = False
2092 error_reason = "Input list shape does not match then-graph shape"
2093
2094 if check:
2095 a = kwargs["a"]
2096 b = kwargs["b"]
2097 basicBlocks = kwargs["basicBlocks"]
2098 then_block = basicBlocks[1]
2099 then_inputs = then_block.inputs
2100 then_tens = then_block.tensors
2101 if (a.shape != then_tens[then_inputs[0]].shape) or (
2102 b.shape != then_tens[then_inputs[1]].shape
2103 ):
2104 error_result = True
2105
2106 info_dict = {
2107 "error_name": error_name,
2108 "error_result": error_result,
2109 "error_reason": error_reason,
2110 "param_reqs": param_reqs,
2111 }
2112 return info_dict
2113
2114 @staticmethod
2115 def evInputListElseGraphMismatch(check=False, **kwargs):
2116 error_name = ErrorIf.CondIfInputListElseGraphMismatch
2117 param_reqs = {"rank": None, "dtype": None, "shape": None}
2118 error_result = False
2119 error_reason = "Input list shape does not match else-graph shape"
2120
2121 if check:
2122 a = kwargs["a"]
2123 b = kwargs["b"]
2124 basicBlocks = kwargs["basicBlocks"]
2125 else_block = basicBlocks[2]
2126 else_inputs = else_block.inputs
2127 else_tens = else_block.tensors
2128 if (a.shape != else_tens[else_inputs[0]].shape) or (
2129 b.shape != else_tens[else_inputs[1]].shape
2130 ):
2131 error_result = True
2132
2133 info_dict = {
2134 "error_name": error_name,
2135 "error_result": error_result,
2136 "error_reason": error_reason,
2137 "param_reqs": param_reqs,
2138 }
2139 return info_dict
2140
2141 @staticmethod
2142 def evOutputListThenGraphMismatch(check=False, **kwargs):
2143 error_name = ErrorIf.CondIfOutputListThenGraphMismatch
2144 param_reqs = {"rank": None, "dtype": None, "shape": None}
2145 error_result = False
2146 error_reason = "Output list shape does not match then-graph shape"
2147
2148 if check:
2149 basicBlocks = kwargs["basicBlocks"]
2150 cond_block = basicBlocks[0]
2151 cond_outputs = cond_block.outputs
2152 cond_tens = cond_block.tensors
2153 then_block = basicBlocks[1]
2154 then_outputs = then_block.outputs
2155 then_tens = then_block.tensors
2156 if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2157 error_result = True
2158
2159 info_dict = {
2160 "error_name": error_name,
2161 "error_result": error_result,
2162 "error_reason": error_reason,
2163 "param_reqs": param_reqs,
2164 }
2165 return info_dict
2166
2167 @staticmethod
2168 def evOutputListElseGraphMismatch(check=False, **kwargs):
2169 error_name = ErrorIf.CondIfOutputListElseGraphMismatch
2170 param_reqs = {"rank": None, "dtype": None, "shape": None}
2171 error_result = False
2172 error_reason = "Output list shape does not match else-graph shape"
2173
2174 if check:
2175 basicBlocks = kwargs["basicBlocks"]
2176 cond_block = basicBlocks[0]
2177 cond_outputs = cond_block.outputs
2178 cond_tens = cond_block.tensors
2179 else_block = basicBlocks[2]
2180 else_outputs = else_block.outputs
2181 else_tens = else_block.tensors
2182 if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
2183 error_result = True
2184
2185 info_dict = {
2186 "error_name": error_name,
2187 "error_result": error_result,
2188 "error_reason": error_reason,
2189 "param_reqs": param_reqs,
2190 }
2191 return info_dict
2192
2193 @staticmethod
2194 def evInputListOutputListMismatch(check=False, **kwargs):
2195 error_name = ErrorIf.InputListOutputListMismatch
2196 param_reqs = {"rank": None, "dtype": None, "shape": None}
2197 error_result = False
2198 error_reason = "Input list does not match output list"
2199
2200 if check:
2201 basicBlocks = kwargs["basicBlocks"]
2202 while_block = basicBlocks[0]
2203 while_inputs = while_block.inputs
2204 while_outputs = while_block.outputs
2205 while_tens = while_block.tensors
2206 if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
2207 error_result = True
2208
2209 info_dict = {
2210 "error_name": error_name,
2211 "error_result": error_result,
2212 "error_reason": error_reason,
2213 "param_reqs": param_reqs,
2214 }
2215 return info_dict
2216
2217 @staticmethod
2218 def evInputListCondGraphMismatch(check=False, **kwargs):
2219 error_name = ErrorIf.InputListCondGraphMismatch
2220 param_reqs = {"rank": None, "dtype": None, "shape": None}
2221 error_result = False
2222 error_reason = "Input list does not match cond graph"
2223
2224 if check:
2225 basicBlocks = kwargs["basicBlocks"]
2226 while_block = basicBlocks[0]
2227 while_inputs = while_block.inputs
2228 while_tens = while_block.tensors
2229 cond_block = basicBlocks[1]
2230 cond_inputs = cond_block.inputs
2231 cond_tens = cond_block.tensors
2232 if (
2233 while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
2234 ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
2235 error_result = True
2236
2237 info_dict = {
2238 "error_name": error_name,
2239 "error_result": error_result,
2240 "error_reason": error_reason,
2241 "param_reqs": param_reqs,
2242 }
2243 return info_dict
2244
2245 @staticmethod
2246 def evInputListBodyGraphInputMismatch(check=False, **kwargs):
2247 error_name = ErrorIf.InputListBodyGraphInputMismatch
2248 param_reqs = {"rank": None, "dtype": None, "shape": None}
2249 error_result = False
2250 error_reason = "Input list does not match body graph input"
2251
2252 if check:
2253 basicBlocks = kwargs["basicBlocks"]
2254 while_block = basicBlocks[0]
2255 while_inputs = while_block.inputs
2256 while_tens = while_block.tensors
2257 body_block = basicBlocks[2]
2258 body_outputs = body_block.inputs
2259 body_tens = body_block.tensors
2260 if (
2261 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2262 ) or (
2263 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2264 ):
2265 error_result = True
2266
2267 info_dict = {
2268 "error_name": error_name,
2269 "error_result": error_result,
2270 "error_reason": error_reason,
2271 "param_reqs": param_reqs,
2272 }
2273 return info_dict
2274
2275 @staticmethod
2276 def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
2277 error_name = ErrorIf.InputListBodyGraphOutputMismatch
2278 param_reqs = {"rank": None, "dtype": None, "shape": None}
2279 error_result = False
2280 error_reason = "Input list does not match body graph output"
2281
2282 if check:
2283 basicBlocks = kwargs["basicBlocks"]
2284 while_block = basicBlocks[0]
2285 while_inputs = while_block.inputs
2286 while_tens = while_block.tensors
2287 body_block = basicBlocks[2]
2288 body_outputs = body_block.outputs
2289 body_tens = body_block.tensors
2290 if (
2291 while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
2292 ) or (
2293 while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
2294 ):
2295 error_result = True
2296 info_dict = {
2297 "error_name": error_name,
2298 "error_result": error_result,
2299 "error_reason": error_reason,
2300 "param_reqs": param_reqs,
2301 }
2302 return info_dict
2303
2304 @staticmethod
2305 def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
2306 error_name = ErrorIf.CondGraphOutputNotMatchingBool
2307 param_reqs = {"rank": None, "dtype": None, "shape": None}
2308 error_result = False
2309 error_reason = "Cond graph output is not a match list of booleans"
2310
2311 if check:
2312 basicBlocks = kwargs["basicBlocks"]
2313 cond_block = basicBlocks[1]
2314 cond_outputs = cond_block.outputs
2315 cond_tens = cond_block.tensors
2316 if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
2317 error_result = True
2318
2319 info_dict = {
2320 "error_name": error_name,
2321 "error_result": error_result,
2322 "error_reason": error_reason,
2323 "param_reqs": param_reqs,
2324 }
2325 return info_dict
2326
2327
2328class TosaInvalidValidator:
2329 @staticmethod
2330 def ivWrongDataTypeOrModeResize(**kwargs):
2331 input_dtype = kwargs["input_dtype"]
2332 args = kwargs["args"]
2333 mode = args[0]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002334 output_dtype = args[5]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002335
2336 if mode == ResizeMode.BILINEAR:
2337 # Invalid output data type / Invalid input datatype
2338 return (
2339 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002340 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
James Ward8b390432022-08-12 20:48:56 +01002341 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
James Ward24dbc422022-10-19 12:20:31 +01002342 and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002343 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002344 )
2345 elif mode == ResizeMode.NEAREST:
2346 # Invalid output data type / Invalid input datatype
2347 return (input_dtype != output_dtype) or (
James Ward24dbc422022-10-19 12:20:31 +01002348 input_dtype
2349 not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002350 )
2351 else:
2352 # Invalid resize mode
2353 return True
2354
    @staticmethod
    def ivHeightWidthInvalid(**kwargs):
        """Return True when an op's height/width parameters make the test invalid.

        Dispatches on ``kwargs["opName"]`` to cover the pooling ops
        (avg_pool2d / max_pool2d), transpose_conv2d, and the conv2d/conv3d
        family.  Op attributes are read positionally from ``kwargs["args"]``,
        whose layout varies per op and is defined by the argument builder
        elsewhere in this generator.

        Returns:
            bool: True if the derived output H/W would be invalid (< 1, or,
            for transpose_conv2d, matching none of the padding schemes tried);
            False otherwise.

        Raises:
            AssertionError: if ``opName`` is not a recognized pooling or
            convolution op.
        """
        opName = kwargs["opName"]

        inputShapes = kwargs["shapeList"]
        # First shape is the input tensor; H/W are read from indices 1 and 2.
        input_shape = inputShapes[0]

        args = kwargs["args"]

        # MaxPool2D has no accum_dtype arg
        stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
        strides = args[stride_idx]
        padding = args[pad_idx]

        if opName.endswith("pool2d"):
            # avg_pool2d, max_pool2d
            # NOTE(review): args[2] is the kernel only when stride/pad sit at
            # indices 0/1 (max_pool2d).  For avg_pool2d the shifted indices
            # above suggest the kernel would live at args[3] — confirm against
            # the op's argument builder.
            kernel_shape = args[2]
            # Pooled output sizes using floor division per spatial dimension.
            h = (
                input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
            ) // strides[0]
            w = (
                input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
            ) // strides[1]
            # return True if any dimension is < 1
            return h < 1 or w < 1

        if opName.startswith("transpose_conv2d"):
            # transpose_conv2d
            output_shape = args[2]
            # Filter is OHWI; the spatial kernel dims are the middle two.
            filter_shape = inputShapes[1]
            kernel_shape = filter_shape[1:-1]

            def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
                """Calculate the transpose_conv2d output size for a dimension.

                Args:
                    in_size: the input size - int
                    stride: the stride - int
                    kernel_size: the kernel size - int
                    out_pad: the output padding - int
                    in_pad: the input padding - int

                Returns:
                    the output size
                """
                return (in_size - 1) * stride + kernel_size - in_pad - out_pad

            # Accept the test if the declared output shape matches the derived
            # size under any of the three classic padding schemes.
            for pad_h, pad_w in (
                (kernel_shape[0] - 1, kernel_shape[1] - 1),  # FULL padding
                (kernel_shape[0] // 2, kernel_shape[1] // 2),  # SAME padding
                (0, 0),  # VALID padding
            ):
                h = get_out_size(
                    input_shape[1],
                    strides[0],
                    kernel_shape[0],
                    padding[0],
                    pad_h,
                )
                w = get_out_size(
                    input_shape[2],
                    strides[1],
                    kernel_shape[1],
                    padding[1],
                    pad_w,
                )
                if output_shape[1] == h and output_shape[2] == w:
                    return False

            # output shape does not match the expected shape for any padding option
            return True

        if "conv2d" in opName or "conv3d" in opName:
            # conv2d, conv3d, depthwise_conv2d
            dilations = args[2]
            filter_shape = inputShapes[1]
            # depthwise filters carry the spatial dims first; regular conv
            # filters are O...I with spatial dims in the middle.
            kernel_shape = (
                filter_shape[0:2]
                if opName.startswith("depthwise_conv2d")
                else filter_shape[1:-1]
            )

            # Standard dilated-convolution output-size formula, one spatial
            # dimension at a time (2 dims for conv2d, 3 for conv3d).
            for i in range(len(kernel_shape)):
                dim = (
                    input_shape[i + 1]
                    - kernel_shape[i]
                    - (kernel_shape[i] - 1) * (dilations[i] - 1)
                    + padding[i * 2 + 0]
                    + padding[i * 2 + 1]
                ) // strides[i] + 1
                # return True if any dimension is < 1
                if dim < 1:
                    return True
            return False

        assert False, f"Unrecognized Op: {opName}"
2451
2452 @staticmethod
2453 def ivNonPositiveOutputShape(**kwargs):
2454 args = kwargs["args"]
James Ward8b390432022-08-12 20:48:56 +01002455 output_shape = args[3]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002456 if output_shape[1] <= 0 or output_shape[2] <= 0:
2457 # Negative output shape
2458 return True
2459 return False