blob: ce3f81f968fd03eefbcf1b193b01b943e78737e5 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010016from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010017from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000018from tosa.DType import DType
19from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010020
21
Eric Kunzee5e26762020-10-13 16:11:07 -070022class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010023 # Maximum rank of tensor supported by test generator.
24 TOSA_TENSOR_MAX_RANK = 6
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026 def __init__(self, args):
27 self.args = args
28 self.basePath = args.output_dir
29 self.random_seed = args.random_seed
30 self.ser = None
31 self.rng = np.random.default_rng(self.random_seed)
32 self.createDynamicOpLists()
33 self.initOpListDefaults()
34 self.quantGen = TosaQuantGen()
35 # Force makeShape to do a specific starting shape
36 self.targetted_shape = None
37
38 def createSerializer(self, opName, testPath):
39 self.testPath = os.path.join(opName, testPath)
40
41 fullPath = os.path.join(self.basePath, self.testPath)
42 os.makedirs(fullPath, exist_ok=True)
43 self.ser = ts.TosaSerializer(fullPath)
44
45 def getSerializer(self):
46 return self.ser
47
48 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080049 with open(
50 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
51 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070052 fd.write(self.ser.serialize())
53
Kevin Cheng550ccc52021-03-03 11:21:43 -080054 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
55 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070056
Matthew Haddon74567092021-07-16 15:38:20 +010057 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000058 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010059 seed = self.random_seed + 1
60 self.rng = np.random.default_rng(seed)
61
Eric Kunzee5e26762020-10-13 16:11:07 -070062 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070063 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070064 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070065 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070066 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070067 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070068 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010069 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
70 elif dtype == DType.UINT8:
71 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070072 elif dtype == DType.INT16:
73 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010074 elif dtype == DType.UINT16:
75 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070076 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080077 return np.int32(
78 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
79 )
Eric Kunzee5e26762020-10-13 16:11:07 -070080 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080081 return np.int64(
82 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
83 )
Eric Kunzee5e26762020-10-13 16:11:07 -070084 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +010085 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070086 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -080087 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -070088
Kevin Cheng989cb052021-04-28 16:29:44 -070089 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -070090 placeholders = []
91
Kevin Cheng989cb052021-04-28 16:29:44 -070092 assert len(shape_list) == len(dtype_list)
93
94 for idx, shape in enumerate(shape_list):
95 arr = self.getRandTensor(shape, dtype_list[idx])
96 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -070097
98 return placeholders
99
Kevin Cheng989cb052021-04-28 16:29:44 -0700100 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700101 consts = []
102
Kevin Cheng989cb052021-04-28 16:29:44 -0700103 assert len(shape_list) == len(dtype_list)
104
105 for idx, shape in enumerate(shape_list):
106 arr = self.getRandTensor(shape, dtype_list[idx])
107 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700108
109 return consts
110
111 def makeShape(self, rank):
112 if self.targetted_shape:
113 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800114 return np.int32(
115 self.rng.integers(
116 low=self.args.tensor_shape_range[0],
117 high=self.args.tensor_shape_range[1],
118 size=rank,
119 )
120 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
122 def setTargetShape(self, shape):
123 self.targetted_shape = shape
124
125 def randInt(self, low=0, high=256):
126 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
127
128 def getRandNumberDType(self, dtype):
129 if dtype == DType.FLOAT:
130 return self.rng.random()
131 elif dtype == DType.BOOL:
132 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700133 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700134 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700135 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100137 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700138 elif dtype == DType.INT16:
139 low, high = (-32768, 32768)
140 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800141 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700142 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800143 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 # Special size
145 return np.int64(self.rng.integers(low, high, size=1))[0]
146 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800147 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700148
149 return np.int32(self.rng.integers(low, high, size=1))[0]
150
151 def shapeStr(self, shape):
152
153 sStr = []
154 # Convert to strings
155 for i in shape:
156 sStr.append(str(i))
157
Kevin Cheng550ccc52021-03-03 11:21:43 -0800158 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700159
160 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -0700161 if isinstance(t, list):
162 assert len(t) >= 2
163 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 else:
Kevin Cheng989cb052021-04-28 16:29:44 -0700165 if t == DType.BOOL:
166 return "b"
167 elif t == DType.INT4:
168 return "i4"
169 elif t == DType.INT8:
170 return "i8"
171 elif t == DType.UINT8:
172 return "u8"
173 elif t == DType.INT16:
174 return "i16"
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100175 elif t == DType.UINT16:
176 return "u16"
Kevin Cheng989cb052021-04-28 16:29:44 -0700177 elif t == DType.INT32:
178 return "i32"
179 elif t == DType.INT48:
180 return "i48"
181 elif t == DType.FLOAT:
182 return "float"
183 else:
184 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700185
186 def typeWidth(self, t):
Jeremy Johnson5d1a3472022-03-31 09:50:06 +0100187 """Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -0800188 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 return 4
190 elif t == DType.INT8:
191 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -0800192 elif t == DType.UINT8:
193 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -0700194 elif t == DType.INT16:
195 return 16
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100196 elif t == DType.UINT16:
197 return 16
Eric Kunzee5e26762020-10-13 16:11:07 -0700198 elif t == DType.INT32:
199 return 32
200 elif t == DType.INT48:
201 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +0100202 elif t == DType.FLOAT:
203 return 32
204 elif t == DType.BOOL:
205 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 else:
Les Bell729b0352021-11-24 10:28:21 +0000207 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
    # Operator build functions
    # The argument generators (see generator.tosa_arg_gen) return a list of
    # tuples (stringDescriptor, [build_fcn_arg_list]), where the string
    # descriptor is used to generate the test name and the build_fcn_arg_list
    # is expanded and passed to one of the build_* functions below.
214
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100215 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
216 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
217
Matthew Haddon848efb42021-09-09 12:30:53 +0100218 # build_placeholder returns an int, ABS/other ops does not
219 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000220 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100221 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000222 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000223 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100224 return result_tens
225
226 # Ensure new output type has correct qinfo
227 if error_name == ErrorIf.WrongOutputType:
228 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000229 qinfo = [
230 TosaQuantGen.getZeroPoint(self, a.dtype),
231 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
232 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100233
234 # Invalidate Input/Output list for error if checks.
235 input_list = [a.name]
236 output_list = [result_tens.name]
237 pCount, cCount = op["operands"]
238 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000239 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
240 self, error_name, input_list, output_list
241 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100242
Les Bell729b0352021-11-24 10:28:21 +0000243 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100244 self.ser,
245 validator_fcns,
246 error_name,
247 op=op,
248 input_dtype=a.dtype,
249 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000250 qinfo=qinfo,
251 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100252 input_list=input_list,
253 output_list=output_list,
254 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000255 ):
256 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100257
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000258 attr = None
259 if op["op"] == Op.NEGATE:
260 attr = ts.TosaSerializerAttribute()
261 attr.NegateAttribute(qinfo[0], qinfo[1])
262
263 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700264 return result_tens
265
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100266 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000267 result_tens = OutputShaper.binaryBroadcastOp(
268 self.ser, self.rng, a, b, error_name
269 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100270
271 # Invalidate Input/Output list for error if checks.
272 input_list = [a.name, b.name]
273 output_list = [result_tens.name]
274 pCount, cCount = op["operands"]
275 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000276 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
277 self, error_name, input_list, output_list
278 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100279
Les Bell729b0352021-11-24 10:28:21 +0000280 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100281 self.ser,
282 validator_fcns,
283 error_name,
284 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000285 input1=a,
286 input2=b,
287 input_dtype=a.dtype,
288 output_dtype=result_tens.dtype,
289 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100290 input_list=input_list,
291 output_list=output_list,
292 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000293 ):
294 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100295
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000296 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700297 return result_tens
298
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100299 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000301 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700302 return result_tens
303
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000304 def build_arithmetic_right_shift(
305 self, op, a, b, round, validator_fcns=None, error_name=None
306 ):
307 result_tens = OutputShaper.binaryBroadcastOp(
308 self.ser, self.rng, a, b, error_name
309 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100310
311 # Invalidate Input/Output list for error if checks.
312 input_list = [a.name, b.name]
313 output_list = [result_tens.name]
314 pCount, cCount = op["operands"]
315 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000316 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
317 self, error_name, input_list, output_list
318 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100319
Les Bell729b0352021-11-24 10:28:21 +0000320 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100321 self.ser,
322 validator_fcns,
323 error_name,
324 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000325 input1=a,
326 input2=b,
327 input_dtype=a.dtype,
328 output_dtype=result_tens.dtype,
329 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100330 input_list=input_list,
331 output_list=output_list,
332 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000333 ):
334 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800335
336 attr = ts.TosaSerializerAttribute()
337 attr.ArithmeticRightShiftAttribute(round)
338
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000339 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800340 return result_tens
341
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100342 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000343 result_tens = OutputShaper.binaryBroadcastOp(
344 self.ser, self.rng, a, b, error_name
345 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700346
347 # Special for multiply:
348 # Force the result to INT32 for INT types
349 if a.dtype != DType.FLOAT:
350 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100351 if error_name == ErrorIf.WrongOutputType:
352 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
353 outputDType = self.rng.choice(all_dtypes)
354 result_tens.setDtype(outputDType)
355
356 # Invalidate Input/Output list for error if checks.
357 input_list = [a.name, b.name]
358 output_list = [result_tens.name]
359 pCount, cCount = op["operands"]
360 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000361 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
362 self, error_name, input_list, output_list
363 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100364
Les Bell729b0352021-11-24 10:28:21 +0000365 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100366 self.ser,
367 validator_fcns,
368 error_name,
369 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000370 input1=a,
371 input2=b,
372 input_dtype=a.dtype,
373 output_dtype=result_tens.dtype,
374 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100375 input_list=input_list,
376 output_list=output_list,
377 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000378 ):
379 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700380
Kevin Chengaee1fac2020-11-11 13:54:06 -0800381 attr = ts.TosaSerializerAttribute()
382 attr.MulAttribute(shift)
383
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700385 return result_tens
386
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100387 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
388 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700389
Kevin Chengfe392ce2021-10-18 21:51:55 +0000390 attr = ts.TosaSerializerAttribute()
391 attr.TableAttribute(table)
392
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100393 # Invalidate Input/Output list for error if checks.
394 input_list = [a.name]
395 output_list = [result_tens.name]
396 pCount, cCount = op["operands"]
397 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000398 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
399 self, error_name, input_list, output_list
400 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100401
Les Bell729b0352021-11-24 10:28:21 +0000402 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100403 self.ser,
404 validator_fcns,
405 error_name,
406 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000407 input_shape=a.shape,
408 input_dtype=a.dtype,
409 output_dtype=result_tens.dtype,
410 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100411 input_list=input_list,
412 output_list=output_list,
413 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000414 ):
415 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100416
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000417 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700418
419 return result_tens
420
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100421 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
422 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
423
424 # Invalidate Input/Output list for error if checks.
425 input_list = [cond.name, a.name, b.name]
426 output_list = [result_tens.name]
427 pCount, cCount = op["operands"]
428 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000429 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
430 self, error_name, input_list, output_list
431 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100432
Les Bell729b0352021-11-24 10:28:21 +0000433 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100434 self.ser,
435 validator_fcns,
436 error_name,
437 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000438 input1=cond,
439 input2=a,
440 input3=b,
441 input_shape=a.shape,
442 input_dtype=a.dtype,
443 output_dtype=result_tens.dtype,
444 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100445 input_list=input_list,
446 output_list=output_list,
447 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000448 ):
449 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100450
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000451 self.ser.addOperator(
452 op["op"],
453 input_list,
454 output_list,
455 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700456 return result_tens
457
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100458 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000459 result_tens = OutputShaper.binaryComparisonOp(
460 self.ser, self.rng, a, b, error_name
461 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100462
463 # Invalidate Input/Output list for error if checks.
464 input_list = [a.name, b.name]
465 output_list = [result_tens.name]
466 pCount, cCount = op["operands"]
467 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000468 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
469 self, error_name, input_list, output_list
470 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100471
Les Bell729b0352021-11-24 10:28:21 +0000472 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100473 self.ser,
474 validator_fcns,
475 error_name,
476 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000477 input1=a,
478 input2=b,
479 input_shape=a.shape,
480 input_dtype=a.dtype,
481 output_shape=result_tens.shape,
482 output_dtype=result_tens.dtype,
483 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100484 input_list=input_list,
485 output_list=output_list,
486 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000487 ):
488 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100489
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000490 self.ser.addOperator(
491 op["op"],
492 input_list,
493 output_list,
494 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700495 return result_tens
496
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100497 def build_argmax(self, op, a, axis, validator_fcns, error_name):
498 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
499
500 # Invalidate Input/Output list for error if checks.
501 input_list = [a.name]
502 output_list = [result_tens.name]
503 pCount, cCount = op["operands"]
504 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000505 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
506 self, error_name, input_list, output_list
507 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100508
Les Bell729b0352021-11-24 10:28:21 +0000509 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100510 self.ser,
511 validator_fcns,
512 error_name,
513 op=op,
514 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000515 input_shape=a.shape,
516 input_dtype=a.dtype,
517 output_shape=result_tens.shape,
518 output_dtype=result_tens.dtype,
519 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100520 input_list=input_list,
521 output_list=output_list,
522 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000523 ):
524 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700525
526 attr = ts.TosaSerializerAttribute()
527 attr.AxisAttribute(axis)
528
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000529 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700530 return result_tens
531
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000532 def build_pool2d(
533 self,
534 op,
535 input,
536 stride,
537 pad,
538 kernel,
539 validator_fcns=None,
540 error_name=None,
541 qinfo=None,
542 ):
543 result_tens = OutputShaper.pool2dOp(
544 self.ser, self.rng, input, kernel, stride, pad, error_name
545 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100546
547 # Ensure new output type has correct qinfo
548 if error_name == ErrorIf.WrongInputType:
549 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000550 qinfo = [
551 TosaQuantGen.getZeroPoint(self, input.dtype),
552 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
553 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100554
555 # Invalidate Input/Output list for error if checks.
556 input_list = [input.name]
557 output_list = [result_tens.name]
558 pCount, cCount = op["operands"]
559 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000560 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
561 self, error_name, input_list, output_list
562 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100563
Les Bell729b0352021-11-24 10:28:21 +0000564 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100565 self.ser,
566 validator_fcns,
567 error_name,
568 op=op,
569 input_shape=input.shape,
570 input_dtype=input.dtype,
571 output_shape=result_tens.shape,
572 output_dtype=result_tens.dtype,
573 kernel=kernel,
574 stride=stride,
575 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000576 qinfo=qinfo,
577 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100578 input_list=input_list,
579 output_list=output_list,
580 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000581 ):
582 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700583
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000584 if qinfo is None:
585 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700586
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000587 attr = ts.TosaSerializerAttribute()
588 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
589
590 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700591 return result_tens
592
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000593 def build_conv2d(
594 self,
595 op,
596 ifm,
597 filter,
598 bias,
599 strides,
600 padding,
601 dilations,
602 validator_fcns=None,
603 error_name=None,
604 qinfo=None,
605 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800606 assert len(padding) == 4
607 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000608 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
609 )
610
611 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000612 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
613 DType.INT8,
614 DType.UINT8,
615 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000616 qinfo = [
617 TosaQuantGen.getZeroPoint(self, ifm.dtype),
618 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
619 ]
Les Bell0e027d42021-11-09 14:42:14 +0000620
621 # Invalidate Input/Output list for error_if checks.
622 input_list = [ifm.name, filter.name, bias.name]
623 output_list = [result_tens.name]
624 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000625 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
626 self, error_name, input_list, output_list
627 )
Les Bell0e027d42021-11-09 14:42:14 +0000628
Les Bell729b0352021-11-24 10:28:21 +0000629 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000630 self.ser,
631 validator_fcns,
632 error_name,
633 op=op,
634 input_dtype=ifm.dtype,
635 weight_dtype=filter.dtype,
636 output_dtype=result_tens.dtype,
637 qinfo=qinfo,
638 input_list=input_list,
639 num_operands=num_operands,
640 output_list=output_list,
641 pad=padding,
642 stride=strides,
643 dilation=dilations,
644 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100645 weight_shape=filter.shape,
646 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000647 ):
648 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700649
650 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000651 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700652
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000653 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700654 return result_tens
655
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000656 def build_conv3d(
657 self,
658 op,
659 ifm,
660 filter,
661 bias,
662 strides,
663 padding,
664 dilations,
665 validator_fcns=None,
666 error_name=None,
667 qinfo=None,
668 ):
Kevin Cheng1533b852021-09-01 12:51:58 -0700669 assert len(padding) == 6
670 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000671 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
672 )
673
674 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000675 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
676 DType.INT8,
677 DType.UINT8,
678 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000679 qinfo = [
680 TosaQuantGen.getZeroPoint(self, ifm.dtype),
681 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
682 ]
Les Bell0e027d42021-11-09 14:42:14 +0000683
684 # Invalidate Input/Output list for error_if checks.
685 input_list = [ifm.name, filter.name, bias.name]
686 output_list = [result_tens.name]
687 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000688 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
689 self, error_name, input_list, output_list
690 )
Les Bell0e027d42021-11-09 14:42:14 +0000691
Les Bell729b0352021-11-24 10:28:21 +0000692 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000693 self.ser,
694 validator_fcns,
695 error_name,
696 op=op,
697 input_dtype=ifm.dtype,
698 weight_dtype=filter.dtype,
699 output_dtype=result_tens.dtype,
700 qinfo=qinfo,
701 input_list=input_list,
702 num_operands=num_operands,
703 output_list=output_list,
704 pad=padding,
705 stride=strides,
706 dilation=dilations,
707 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100708 weight_shape=filter.shape,
709 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000710 ):
711 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700712
713 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000714 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Kevin Cheng1533b852021-09-01 12:51:58 -0700715
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000716 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700717 return result_tens
718
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TRANSPOSE_CONV2D operator to the serializer.

    Inputs are ``ifm``, ``filter`` and ``bias``; ``out_pad`` must contain
    exactly four values (asserted). ``qinfo`` is indexed as ``qinfo[0]`` and
    ``qinfo[1]`` and forwarded to TransposeConvAttribute. When it is
    regenerated here, the first entry comes from the input dtype and the
    second from the output dtype.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the configuration.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, error_name
    )

    # WrongInputType tests with a non-8-bit input need zero points that
    # match the actual input/output dtypes.
    refresh_qinfo = (
        error_name == ErrorIf.WrongInputType
        and ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if refresh_qinfo:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, result_tens.dtype)
        qinfo = [input_zp, output_zp]

    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [result_tens.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=sum(op["operands"]),
        output_list=outputs,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not checks_ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
780
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DEPTHWISE_CONV2D operator to the serializer.

    Inputs are ``ifm``, ``filter`` and ``bias``. ``padding``, ``strides`` and
    ``dilations`` are forwarded to ConvAttribute together with the two
    zero points held in ``qinfo``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # Regenerate zero points for WrongInputType tests on non-8-bit inputs.
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    operand_count = sum(op["operands"])
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    validation_args = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
842
def build_fully_connected(
    self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a FULLY_CONNECTED operator with inputs ``ifm``, ``filter``, ``bias``.

    ``qinfo[0]`` and ``qinfo[1]`` are forwarded to FullyConnectedAttribute.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
882
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Add a MATMUL operator multiplying tensors ``a`` and ``b``.

    ``qinfo[0]`` and ``qinfo[1]`` are forwarded to MatMulAttribute.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "input_dtype": a.dtype,
        "input2_shape": b.shape,
        "input2_dtype": b.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "qinfo": qinfo,
        "result_tensor": out_tensor,
        "input_list": inputs,
        "output_list": outputs,
        "num_operands": n_placeholders + n_consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
919
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reduction operator (``op["op"]``) over ``axis`` of tensor ``a``.

    ``validator_fcns`` defaults to None, like the other build_* methods, so
    the builder can be called without ERROR_IF validators.

    Returns the result tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the configuration.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
954
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a CLAMP operator on tensor ``a`` with randomly chosen bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller becomes
    the minimum and the larger the maximum. For ErrorIf.MaxSmallerMin the
    draw is repeated until the values differ, and the bounds are swapped so
    that max < min.

    FLOAT inputs place the bounds in the float slots of ClampAttribute; all
    other dtypes use the integer slots.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    v = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        max_val, min_val = max(v), min(v)
    else:
        # Identical values could not produce max < min, so redraw.
        while v[0] == v[1]:
            v = draw_pair()
        max_val, min_val = min(v), max(v)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype == DType.FLOAT:
        clamp_args = (0, 0, min_val, max_val)
    else:
        clamp_args = (min_val, max_val, 0, 0)
    attr.ClampAttribute(*clamp_args)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
1005
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a LEAKY_RELU operator on ``a`` with a random FLOAT alpha attribute.

    NOTE(review): ``validator_fcns`` is accepted but never used, so no
    ERROR_IF validation runs for this operator.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FLOAT)
    attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
1014
# NOTE: PRELU needs an additional type/input that this builder does not yet provide.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PRELU operator whose only input is ``a``.

    ``validator_fcns`` is accepted but unused, so no ERROR_IF validation
    runs. Returns the result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1021
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a SIGMOID operator on tensor ``a``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1052
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a TANH operator on tensor ``a``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    operand_total = n_placeholders + n_consts
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensor": out_tensor,
        "input_list": inputs,
        "output_list": outputs,
        "num_operands": operand_total,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1083
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a CONCAT operator joining a variable number of tensors.

    The positional arguments are the input tensors followed by the integer
    concatenation axis, because the axis travels with the variable-length
    tensor list. The integer check on the trailing axis is skipped for
    ErrorIf.WrongInputType tests.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    axis = a[-1]
    tensors = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=result_tens.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=tensors,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001132
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator on tensor ``a``.

    ``padding`` is flattened (via its ``flatten()`` method) into PadAttribute
    along with the integer and float pad constants. ``qinfo`` is used only
    for ERROR_IF validation.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001178
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a RESHAPE operator turning ``a`` into shape ``newShape``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensor": out_tensor,
        "input_list": inputs,
        "output_list": outputs,
        "num_operands": n_placeholders + n_consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1214
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a REVERSE operator on tensor ``a`` along ``axis``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1249
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator permuting the dimensions of ``a`` by ``perms``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out_tensor.shape,
        "perms": perms,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensor": out_tensor,
        "input_list": inputs,
        "output_list": outputs,
        "num_operands": n_placeholders + n_consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1284
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator taking ``size`` elements of ``a`` from ``start``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1322
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator repeating ``a`` according to ``multiples``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": a.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensor": out_tensor,
        "input_list": inputs,
        "output_list": outputs,
        "num_operands": n_placeholders + n_consts,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1356
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a GATHER operator reading from ``values`` with generated indices.

    A constant INT32 index tensor of shape (values.shape[0], W) is created.
    W is a random size within ``self.args.tensor_shape_range``, and every
    index lies in [0, values.shape[1]), so it never exceeds the dimensions
    of ``values``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    K = values.shape[1]
    shape_lo = self.args.tensor_shape_range[0]
    shape_hi = self.args.tensor_shape_range[1]
    W = self.randInt(shape_lo, shape_hi)
    index_data = np.int32(
        self.rng.integers(low=0, high=K, size=[values.shape[0], W])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    n_placeholders, n_consts = op["operands"]
    # The helper may invalidate the operand name lists for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)

    return out_tensor
1403
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a SCATTER operator writing ``input`` into ``values_in``.

    A constant INT32 index tensor of shape (N, W) is created here:
    N = values_in.shape[0], K = values_in.shape[1], W = input.shape[1], and
    every index lies in [0, K).

    TOSA SCATTER does not allow one output position to be written twice
    within a batch. When W <= K, each batch therefore draws its W indices
    without replacement. When W > K, unique indices are impossible and the
    indices are sampled with replacement, as before.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    N = values_in.shape[0]
    K = values_in.shape[1]
    W = input.shape[1]

    if W <= K:
        # One independent set of unique indices per batch entry.
        rows = [self.rng.choice(K, size=W, replace=False) for _ in range(N)]
        indicies_arr = np.array(rows, dtype=np.int32).reshape(N, W)
    else:
        # Uniqueness cannot be satisfied; keep the original sampling.
        indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indicies, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indicies.name, input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1448
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator.

    The result tensor comes from ``OutputShaper.resizeOp`` using the
    mode/scale/offset/border parameters. ERROR_IF validation runs before the
    operator is added with a ResizeAttribute(scale, offset, border, mode).

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    primary_count, const_count = op["operands"]
    # Possibly corrupt the input/output name lists for ERROR_IF tests
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    check_args = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=result_tens.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=primary_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
1510
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator forwarding ``val`` and ``val2``.

    One unary-shaped output tensor is created per input and both are
    registered as operator outputs, but only the first one is returned.
    No ERROR_IF validation is performed by this builder.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator-description dict (as in every other build_*
    # method); the serializer needs the Op value stored under "op".
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1518
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark the already-created tensor ``val`` as a graph output.

    No operator is added; ``op``, ``validator_fcns`` and ``error_name`` are
    accepted only so the signature matches the other builders.
    Returns ``val`` itself.
    """
    const_tens = val
    self.ser.addOutputTensor(const_tens)
    return const_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001522
1523 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting ``val`` to ``out_dtype``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    primary_count, const_count = op["operands"]
    # Possibly corrupt the input/output name lists for ERROR_IF tests
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    check_args = {
        "op": op,
        "input_shape": val.shape,
        "output_shape": result_tens.shape,
        "input_dtype": val.dtype,
        "output_dtype": result_tens.dtype,
        "result_tensor": result_tens,
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": primary_count + const_count,
    }
    if TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    ):
        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens
    return None
1556
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Steps:
    - Pick random input/output zero points suited to the dtypes. The
      zero-point ERROR_IF tests get deliberately non-zero ones instead.
    - Draw a random scale per channel (last dimension of ``val``) when
      ``per_channel`` is set, otherwise a single scale.
    - Convert each scale to a multiplier/shift pair with
      ``TosaQuantGen.computeMultiplierAndShift``.

    For positive scale32 tests, the placeholder data file behind ``val`` is
    clamped so that every value meets the apply_scale_32 range requirement
    for its shift. The file is rewritten only when a value changed.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) or a single shared scale
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero input zero point for the ERROR_IF test
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero output zero point for the ERROR_IF test
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Do the bound arithmetic explicitly in int64. shift_arr holds int32
        # values, and under NumPy's NEP 50 promotion rules ``-1 << shift``
        # would stay int32 and overflow for shifts above 31. The bound
        # arrays are int64, so such shifts are presumably expected for scale32.
        shift = np.int64(shift_arr[i])
        min_shift_value_arr[i] = np.int64(-1) << (shift - 1)
        max_shift_value_arr[i] = (np.int64(1) << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1711
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN and ELSE blocks each output a constant.

    ``then_tens`` is used only for its shape and ``else_tens`` is ignored.
    ``cond`` becomes a constant BOOL condition tensor.

    Both blocks output a random INT32 constant of ``then_tens``'s shape.
    The exception is the CondIfOutputList{Then,Else}GraphMismatch ERROR_IF
    tests: there, the block under test outputs a constant of a deliberately
    wrong shape.

    Returns the INT32 result tensor, or None when the ERROR_IF validation
    fails.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])
    out_shape = then_tens.shape

    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        # Perturb every dimension; dims <= 3 are only grown so they stay positive
        incorrect_shape = deepcopy(then_tens.shape)
        for dim_idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[dim_idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result has the (correct) output shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block just outputs a constant: its normal random data, or the
    # wrongly-shaped data when this block is the one under ERROR_IF test.
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, block_error, block_arr in block_specs:
        self.ser.startBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1779
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a``, ``b``.

    FLOAT/INT32 inputs use ADD (THEN) and SUB (ELSE). INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (THEN) and LOGICAL_LEFT_SHIFT (ELSE).

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the corresponding block gets an input or output of a deliberately wrong
    shape.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be named ``op``. That name would shadow the
    # operator-description dict, and the ERROR_IF validators below would be
    # handed the ELSE block's Op value instead.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            # Block input shape deliberately differs from the COND_IF input
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            # Block output shape deliberately differs from the COND_IF output
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1861
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that repeatedly adds ``a`` into an accumulator.

    The loop carries (counter, a, accumulator). The counter starts at
    ``iter_val`` and the accumulator at zero. The COND block tests
    ``counter > 0``. The BODY block computes ``acc + a`` and ``counter - 1``.

    For the InputList*/CondGraph* ERROR_IF tests, one of the graphs is given
    a deliberately wrong shape or dtype.

    Returns the accumulator output tensor, or None when the ERROR_IF
    validation fails.
    """
    shape_offsets = [-3, -2, 2, 3]

    def perturbed_copy(tens):
        # Deep copy of ``tens`` with every dimension moved by a random offset
        bad = deepcopy(tens)
        for dim_idx in range(len(bad.shape)):
            bad.shape[dim_idx] += self.rng.choice(shape_offsets)
        return bad

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = perturbed_copy(iter_tens)
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so give it an extra (random) dimension
            incorrect_iter.shape.append(self.rng.choice(shape_offsets))
        incorrect_acc = perturbed_copy(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_dtype = self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
    else:
        cond_dtype = DType.BOOL
    cond_tens = self.ser.addOutput([], cond_dtype)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter, a, acc)
    # Local intermediate tensors need to be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = incorrect_iter, incorrect_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return acc_out
1973
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters to use for one operator.

    Args:
        op: operator description dict; ``op["rank"]`` is the (min, max)
            allowed rank and ``op["types"]`` the supported dtypes.
        shapeFilter: requested shapes, or a falsy value for "any shape".
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator function, used for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries.
        Returns None for a negative test without a validator, or for any
        other ``testType``.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Bound the default behaviour by the default range or by the
        # operator, whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Never pick a default rank the operator does not support
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # No overlap with the default range: use the lowest ranks
                # the operator allows instead.
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep the operator dtypes that were requested; a list entry matches
        # on its first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Use the parameters required by the ERROR_IF test where it sets them
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2054
def genOpTestList(
    self,
    opName,
    shapeFilter=[None],
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptions for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: requested shapes (passed to create_filter_lists).
            The default list is only read here, never mutated.
        rankFilter: requested ranks (passed to create_filter_lists).
        dtypeFilter: requested dtypes (passed to create_filter_lists).
        testType: "positive" for legal tests, or "negative" for ERROR_IF
            tests (one set per entry in the op's "error_if_validators").

    Returns:
        A list of tuples
        (opName, testStr, dtype, error_name, shapeList, args).
        For positive tests, any test flagged by ANY of the op's
        "invalid_test_validators" is removed.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        ValueError: if testType is not "positive" or "negative".
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Any other value would leave the test name undefined below.
    if testType not in ("positive", "negative"):
        raise ValueError(
            "testType must be 'positive' or 'negative', got {!r}".format(testType)
        )

    # Reset the RNG from the fixed seed at the start of every call.
    self.rng = np.random.default_rng(self.random_seed)

    # Unpacking also checks that build_fcn is a 4-tuple.
    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # NOTE(review): this also discards tests already generated for
            # earlier validators; confirm that is intended.
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) pairs: the test name
                    # suffix and the build function argument list.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    if testType == "positive":
                        baseStr = f"{opName}_{shapeStr}_{typeStr}"
                    else:
                        baseStr = f"{opName}_ERRORIF_{error_name}_{shapeStr}_{typeStr}"

                    for argStr, args in argList:
                        testStr = f"{baseStr}_{argStr}" if argStr else baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped if ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2161
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build and serialize a single test case.

    Creates a serializer for (opName, testStr), generates the operand tensors
    with the op's TensorValuesGen function and calls the op's build function.
    If the build function returns a truthy result name, self.serialize("test")
    is called; otherwise an "Invalid ERROR_IF test created" message is printed.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: test name, passed to createSerializer.
        dtype_or_dtypeList: a single dtype applied to every operand, or a
            list with one dtype per operand.
        error_name: ERROR_IF error name for negative tests, or None.
        shapeList: operand shapes.
        testArgs: extra positional arguments for the build function.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        AssertionError: if shapeList/dtypeList lengths do not match the
            operand count (non-CONCAT ops only).
        TypeError: re-raised from the build call after printing its inputs.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT's operand count follows shapeList, not op["operands"].
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    # Ops with ERROR_IF validators take validator_fcns/error_name/qinfo as
    # keywords; ops without them take qinfo (if any) as a trailing positional.
    extra_args = []
    build_kwargs = {}
    if error_if_validators is None:
        if qinfo is not None:
            extra_args.append(qinfo)
    else:
        build_kwargs["validator_fcns"] = error_if_validators
        build_kwargs["error_name"] = error_name
        if qinfo is not None:
            build_kwargs["qinfo"] = qinfo

    try:
        resultName = build_fcn(self, op, *tens, *testArgs, *extra_args, **build_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002252
def createDynamicOpLists(self):
    """Expand the templated convolution ops in TOSA_OP_LIST.

    Each "<prefix>_TEMPLATE" entry is shallow-copied once per kernel size into
    a concrete op named e.g. "conv2d_3x3" or "conv3d_1x1x2". The copy gets
    "filter" set to the kernel and "template" set to False. Afterwards every
    entry whose "template" field is truthy is deleted.

    TOSA_OP_LIST is a class attribute, so this mutates state shared by every
    instance. If a template is gone but its expanded op already exists (an
    earlier instance expanded it), that variant is skipped instead of
    raising KeyError. This lets the generator be constructed more than once.
    A template that is missing and never expanded still raises KeyError.
    """
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    op_list = self.TOSA_OP_LIST

    def add_kernel_variant(prefix, kernel):
        # Create "<prefix>_<k0>x<k1>[x<k2>]" from "<prefix>_TEMPLATE".
        template_name = f"{prefix}_TEMPLATE"
        test_name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        if template_name not in op_list and test_name in op_list:
            return
        entry = op_list[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[test_name] = entry

    # Kernel-major order matches the original insertion order of the ops.
    for kernel in KERNELS_2D:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_kernel_variant(prefix, kernel)

    for kernel in KERNELS_3D:
        add_kernel_variant("conv3d", kernel)

    # Delete any templates after having created any dynamic ops.
    # Collect first: deleting keys while iterating a dict raises RuntimeError.
    template_keys = [name for name, entry in op_list.items() if entry.get("template")]
    for name in template_keys:
        del op_list[name]
2299
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each op must provide a 2-element "operands", a 4-element "build_fcn",
    a "types" entry and an "op" entry. A missing "rank" is set to
    DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op with a missing or malformed required
            field.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts.
        try:
            _num_placeholders, _num_consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        # Required: (build, tensor gen, tensor values gen, arg gen).
        try:
            _build, _tensor_gen, _values_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(f"Op {name} is missing a valid type list in TOSA_OP_LIST")

        if "op" not in entry:
            raise Exception(f"Op {name} is missing the Op field in TOSA_OP_LIST")

        # Optional: fall back to the default rank range.
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002341
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#         if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE,
#         i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, TensorGen function,
#              TensorValuesGen function, ArgGen function or None)
# 'types': array of datatypes to be tested; an entry is either a single DType
#          (applied to every operand) or a list with one DType per operand
# 'qgen': optional, quantization info generator; its result is passed to the
#         build function as qinfo
# 'error_if_validators': optional, validators used to generate negative
#         (ERROR_IF) tests, one set per validator; also passed to the build
#         function as validator_fcns
# 'invalid_test_validators': optional, validators used to drop positive tests
#         that are expected to fail but don't correlate to an ERROR_IF statement
# 'template': optional, True marks an entry that createDynamicOpLists expands
#         into one op per kernel size (stored in 'filter') and then removes
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Convolution / fully-connected dtypes: each list entry gives one DType per
# operand (presumably input, weight, bias - confirm against the build
# functions); the bare DType.FLOAT entry is applied to every operand.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002369
2370 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002371 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002372 "argmax": {
2373 "op": Op.ARGMAX,
2374 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002375 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002376 "build_fcn": (
2377 build_argmax,
2378 TosaTensorGen.tgBasic,
2379 TosaTensorValuesGen.tvgDefault,
2380 TosaArgGen.agAxis,
2381 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002382 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002383 "error_if_validators": (
2384 TosaErrorValidator.evAxisSmallerZero,
2385 TosaErrorValidator.evAxisLargerRank,
2386 TosaErrorValidator.evArgmaxOutputRankMismatch,
2387 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2388 TosaErrorValidator.evWrongRank,
2389 TosaErrorValidator.evWrongInputType,
2390 TosaErrorValidator.evWrongOutputType,
2391 TosaErrorValidator.evWrongInputList,
2392 TosaErrorValidator.evWrongOutputList,
2393 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002394 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002395 "avg_pool2d": {
2396 "op": Op.AVG_POOL2D,
2397 "operands": (1, 0),
2398 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002399 "build_fcn": (
2400 build_pool2d,
2401 TosaTensorGen.tgNHWC,
2402 TosaTensorValuesGen.tvgDefault,
2403 TosaArgGen.agPooling,
2404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002405 "qgen": TosaQuantGen.qgUnary,
2406 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002407 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002408 "error_if_validators": (
2409 TosaErrorValidator.evKernelSmallerOne,
2410 TosaErrorValidator.evStrideSmallerOne,
2411 TosaErrorValidator.evPadSmallerZero,
2412 TosaErrorValidator.evWrongRank,
2413 TosaErrorValidator.evWrongInputType,
2414 TosaErrorValidator.evWrongOutputType,
2415 TosaErrorValidator.evWrongInputList,
2416 TosaErrorValidator.evWrongOutputList,
2417 TosaErrorValidator.evInputZeroPointNotZero,
2418 TosaErrorValidator.evOutputZeroPointNotZero,
2419 TosaErrorValidator.evPadLargerEqualKernel,
2420 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002421 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002422 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002423 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002424 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002425 "conv2d_TEMPLATE": {
2426 "op": Op.CONV2D,
2427 "operands": (1, 2),
2428 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002429 "build_fcn": (
2430 build_conv2d,
2431 TosaTensorGen.tgConv2D,
2432 TosaTensorValuesGen.tvgDefault,
2433 TosaArgGen.agConv,
2434 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002435 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002436 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002437 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2438 "error_if_validators": (
2439 TosaErrorValidator.evWrongInputType,
2440 TosaErrorValidator.evWrongOutputType,
2441 TosaErrorValidator.evWrongInputList,
2442 TosaErrorValidator.evWrongOutputList,
2443 TosaErrorValidator.evInputZeroPointNotZero,
2444 TosaErrorValidator.evWeightZeroPointNotZero,
2445 TosaErrorValidator.evPadSmallerZero,
2446 TosaErrorValidator.evStrideSmallerOne,
2447 TosaErrorValidator.evDilationSmallerOne,
2448 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002449 TosaErrorValidator.evConvOutputShapeMismatch,
2450 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002451 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002452 "template": True,
2453 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002454 # Templated operator. Filled in by createDynamicOpLists
2455 "conv3d_TEMPLATE": {
2456 "op": Op.CONV3D,
2457 "operands": (1, 2),
2458 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002459 "build_fcn": (
2460 build_conv3d,
2461 TosaTensorGen.tgConv3D,
2462 TosaTensorValuesGen.tvgDefault,
2463 TosaArgGen.agConv,
2464 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002465 "qgen": TosaQuantGen.qgConv,
2466 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002467 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2468 "error_if_validators": (
2469 TosaErrorValidator.evWrongInputType,
2470 TosaErrorValidator.evWrongOutputType,
2471 TosaErrorValidator.evWrongInputList,
2472 TosaErrorValidator.evWrongOutputList,
2473 TosaErrorValidator.evInputZeroPointNotZero,
2474 TosaErrorValidator.evWeightZeroPointNotZero,
2475 TosaErrorValidator.evPadSmallerZero,
2476 TosaErrorValidator.evStrideSmallerOne,
2477 TosaErrorValidator.evDilationSmallerOne,
2478 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002479 TosaErrorValidator.evConvOutputShapeMismatch,
2480 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002481 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002482 "template": True,
2483 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002484 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002485 "depthwise_conv2d_TEMPLATE": {
2486 "op": Op.DEPTHWISE_CONV2D,
2487 "operands": (1, 2),
2488 "filter": [1, 1],
2489 "rank": (4, 4),
2490 "build_fcn": (
2491 build_depthwise_conv2d,
2492 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002493 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002494 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002495 ),
2496 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002497 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002498 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2499 "error_if_validators": (
2500 TosaErrorValidator.evWrongInputType,
2501 TosaErrorValidator.evWrongOutputType,
2502 TosaErrorValidator.evWrongInputList,
2503 TosaErrorValidator.evWrongOutputList,
2504 TosaErrorValidator.evInputZeroPointNotZero,
2505 TosaErrorValidator.evWeightZeroPointNotZero,
2506 TosaErrorValidator.evPadSmallerZero,
2507 TosaErrorValidator.evStrideSmallerOne,
2508 TosaErrorValidator.evDilationSmallerOne,
2509 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002510 TosaErrorValidator.evConvOutputShapeMismatch,
2511 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002512 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002513 "template": True,
2514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002515 "fully_connected": {
2516 "op": Op.FULLY_CONNECTED,
2517 "operands": (1, 2),
2518 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002519 "build_fcn": (
2520 build_fully_connected,
2521 TosaTensorGen.tgFullyConnected,
2522 TosaTensorValuesGen.tvgDefault,
2523 None,
2524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002525 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002526 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002527 "error_if_validators": (
2528 TosaErrorValidator.evInputZeroPointNotZero,
2529 TosaErrorValidator.evWeightZeroPointNotZero,
2530 TosaErrorValidator.evWrongRank,
2531 TosaErrorValidator.evWrongInputType,
2532 TosaErrorValidator.evWrongOutputType,
2533 TosaErrorValidator.evWrongInputList,
2534 TosaErrorValidator.evWrongOutputList,
2535 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002536 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002537 "matmul": {
2538 "op": Op.MATMUL,
2539 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002540 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002541 "build_fcn": (
2542 build_matmul,
2543 TosaTensorGen.tgMatmul,
2544 TosaTensorValuesGen.tvgDefault,
2545 None,
2546 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002547 "qgen": TosaQuantGen.qgMatmul,
2548 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002549 "error_if_validators": (
2550 TosaErrorValidator.evInputZeroPointNotZero,
2551 TosaErrorValidator.evWrongRank,
2552 TosaErrorValidator.evWrongInputType,
2553 TosaErrorValidator.evWrongOutputType,
2554 TosaErrorValidator.evWrongInputList,
2555 TosaErrorValidator.evWrongOutputList,
2556 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002557 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002558 "max_pool2d": {
2559 "op": Op.MAX_POOL2D,
2560 "operands": (1, 0),
2561 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002562 "build_fcn": (
2563 build_pool2d,
2564 TosaTensorGen.tgNHWC,
2565 TosaTensorValuesGen.tvgDefault,
2566 TosaArgGen.agPooling,
2567 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002568 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002569 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002570 "error_if_validators": (
2571 TosaErrorValidator.evKernelSmallerOne,
2572 TosaErrorValidator.evStrideSmallerOne,
2573 TosaErrorValidator.evPadSmallerZero,
2574 TosaErrorValidator.evWrongRank,
2575 TosaErrorValidator.evWrongInputType,
2576 TosaErrorValidator.evWrongOutputType,
2577 TosaErrorValidator.evWrongInputList,
2578 TosaErrorValidator.evWrongOutputList,
2579 TosaErrorValidator.evPadLargerEqualKernel,
2580 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002581 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002582 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002583 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002584 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002585 "transpose_conv2d_TEMPLATE": {
2586 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002587 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002588 "rank": (4, 4),
2589 "build_fcn": (
2590 build_transpose_conv2d,
2591 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002592 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002593 TosaArgGen.agTransposeConv2D,
2594 ),
2595 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002596 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002597 "invalid_test_validators": (
2598 TosaInvalidValidator.ivHeightWidthInvalid,
2599 TosaInvalidValidator.ivNonPositiveOutputShape,
2600 ),
2601 "error_if_validators": (
2602 TosaErrorValidator.evWrongInputType,
2603 TosaErrorValidator.evWrongOutputType,
2604 TosaErrorValidator.evWrongInputList,
2605 TosaErrorValidator.evWrongOutputList,
2606 TosaErrorValidator.evInputZeroPointNotZero,
2607 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002608 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002609 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002610 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002611 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002612 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002613 "template": True,
2614 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002615 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002616 "clamp": {
2617 "op": Op.CLAMP,
2618 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002619 "build_fcn": (
2620 build_clamp,
2621 TosaTensorGen.tgBasic,
2622 TosaTensorValuesGen.tvgDefault,
2623 None,
2624 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002626 "error_if_validators": (
2627 TosaErrorValidator.evMaxSmallerMin,
2628 TosaErrorValidator.evWrongInputType,
2629 TosaErrorValidator.evWrongOutputType,
2630 TosaErrorValidator.evWrongInputList,
2631 TosaErrorValidator.evWrongOutputList,
2632 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002633 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002634 "sigmoid": {
2635 "op": Op.SIGMOID,
2636 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002637 "build_fcn": (
2638 build_sigmoid,
2639 TosaTensorGen.tgBasic,
2640 TosaTensorValuesGen.tvgDefault,
2641 None,
2642 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002643 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002644 "error_if_validators": (
2645 TosaErrorValidator.evWrongInputType,
2646 TosaErrorValidator.evWrongOutputType,
2647 TosaErrorValidator.evWrongInputList,
2648 TosaErrorValidator.evWrongOutputList,
2649 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002650 },
2651 "tanh": {
2652 "op": Op.TANH,
2653 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002654 "build_fcn": (
2655 build_tanh,
2656 TosaTensorGen.tgBasic,
2657 TosaTensorValuesGen.tvgDefault,
2658 None,
2659 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002660 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002661 "error_if_validators": (
2662 TosaErrorValidator.evWrongInputType,
2663 TosaErrorValidator.evWrongOutputType,
2664 TosaErrorValidator.evWrongInputList,
2665 TosaErrorValidator.evWrongOutputList,
2666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002668 # Elementwise Binary Operators
2669 "add": {
2670 "op": Op.ADD,
2671 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002672 "build_fcn": (
2673 build_binary_broadcast,
2674 TosaTensorGen.tgBroadcastFuzz,
2675 TosaTensorValuesGen.tvgAddSub,
2676 None,
2677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002678 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002679 "error_if_validators": (
2680 TosaErrorValidator.evRankMismatch,
2681 TosaErrorValidator.evWrongInputType,
2682 TosaErrorValidator.evWrongOutputType,
2683 TosaErrorValidator.evWrongInputList,
2684 TosaErrorValidator.evWrongOutputList,
2685 TosaErrorValidator.evDimensionMismatch,
2686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002687 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002688 "arithmetic_right_shift": {
2689 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2690 "operands": (2, 0),
2691 "build_fcn": (
2692 build_arithmetic_right_shift,
2693 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002694 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002695 TosaArgGen.agArithmeticRightShift,
2696 ),
2697 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002698 "error_if_validators": (
2699 TosaErrorValidator.evRankMismatch,
2700 TosaErrorValidator.evWrongInputType,
2701 TosaErrorValidator.evWrongOutputType,
2702 TosaErrorValidator.evWrongInputList,
2703 TosaErrorValidator.evWrongOutputList,
2704 TosaErrorValidator.evDimensionMismatch,
2705 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002706 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002707 "bitwise_and": {
2708 "op": Op.BITWISE_AND,
2709 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002710 "build_fcn": (
2711 build_binary_broadcast,
2712 TosaTensorGen.tgBroadcastFuzz,
2713 TosaTensorValuesGen.tvgDefault,
2714 None,
2715 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002716 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002717 "error_if_validators": (
2718 TosaErrorValidator.evRankMismatch,
2719 TosaErrorValidator.evWrongInputType,
2720 TosaErrorValidator.evWrongOutputType,
2721 TosaErrorValidator.evWrongInputList,
2722 TosaErrorValidator.evWrongOutputList,
2723 TosaErrorValidator.evDimensionMismatch,
2724 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002725 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002726 "bitwise_or": {
2727 "op": Op.BITWISE_OR,
2728 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002729 "build_fcn": (
2730 build_binary_broadcast,
2731 TosaTensorGen.tgBroadcastFuzz,
2732 TosaTensorValuesGen.tvgDefault,
2733 None,
2734 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002735 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002736 "error_if_validators": (
2737 TosaErrorValidator.evRankMismatch,
2738 TosaErrorValidator.evWrongInputType,
2739 TosaErrorValidator.evWrongOutputType,
2740 TosaErrorValidator.evWrongInputList,
2741 TosaErrorValidator.evWrongOutputList,
2742 TosaErrorValidator.evDimensionMismatch,
2743 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002745 "bitwise_xor": {
2746 "op": Op.BITWISE_XOR,
2747 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002748 "build_fcn": (
2749 build_binary_broadcast,
2750 TosaTensorGen.tgBroadcastFuzz,
2751 TosaTensorValuesGen.tvgDefault,
2752 None,
2753 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002754 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002755 "error_if_validators": (
2756 TosaErrorValidator.evRankMismatch,
2757 TosaErrorValidator.evWrongInputType,
2758 TosaErrorValidator.evWrongOutputType,
2759 TosaErrorValidator.evWrongInputList,
2760 TosaErrorValidator.evWrongOutputList,
2761 TosaErrorValidator.evDimensionMismatch,
2762 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002763 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002764 "intdiv": {
2765 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002766 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002767 "build_fcn": (
2768 build_binary_broadcast,
2769 TosaTensorGen.tgBroadcastFuzz,
2770 TosaTensorValuesGen.tvgIntDiv,
2771 None,
2772 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002773 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002774 "error_if_validators": (
2775 TosaErrorValidator.evRankMismatch,
2776 TosaErrorValidator.evWrongInputType,
2777 TosaErrorValidator.evWrongOutputType,
2778 TosaErrorValidator.evWrongInputList,
2779 TosaErrorValidator.evWrongOutputList,
2780 TosaErrorValidator.evDimensionMismatch,
2781 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002782 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002783 "logical_and": {
2784 "op": Op.LOGICAL_AND,
2785 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002786 "build_fcn": (
2787 build_binary_broadcast,
2788 TosaTensorGen.tgBroadcastFuzz,
2789 TosaTensorValuesGen.tvgDefault,
2790 None,
2791 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002792 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002793 "error_if_validators": (
2794 TosaErrorValidator.evRankMismatch,
2795 TosaErrorValidator.evWrongInputType,
2796 TosaErrorValidator.evWrongOutputType,
2797 TosaErrorValidator.evWrongInputList,
2798 TosaErrorValidator.evWrongOutputList,
2799 TosaErrorValidator.evDimensionMismatch,
2800 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002801 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002802 "logical_left_shift": {
2803 "op": Op.LOGICAL_LEFT_SHIFT,
2804 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002805 "build_fcn": (
2806 build_binary_broadcast,
2807 TosaTensorGen.tgBroadcastFuzz,
2808 TosaTensorValuesGen.tvgLogicalShift,
2809 None,
2810 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002811 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002812 "error_if_validators": (
2813 TosaErrorValidator.evRankMismatch,
2814 TosaErrorValidator.evWrongInputType,
2815 TosaErrorValidator.evWrongOutputType,
2816 TosaErrorValidator.evWrongInputList,
2817 TosaErrorValidator.evWrongOutputList,
2818 TosaErrorValidator.evDimensionMismatch,
2819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002821 "logical_right_shift": {
2822 "op": Op.LOGICAL_RIGHT_SHIFT,
2823 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002824 "build_fcn": (
2825 build_binary_broadcast,
2826 TosaTensorGen.tgBroadcastFuzz,
2827 TosaTensorValuesGen.tvgLogicalShift,
2828 None,
2829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002830 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002831 "error_if_validators": (
2832 TosaErrorValidator.evRankMismatch,
2833 TosaErrorValidator.evWrongInputType,
2834 TosaErrorValidator.evWrongOutputType,
2835 TosaErrorValidator.evWrongInputList,
2836 TosaErrorValidator.evWrongOutputList,
2837 TosaErrorValidator.evDimensionMismatch,
2838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002839 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002840 "logical_or": {
2841 "op": Op.LOGICAL_OR,
2842 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002843 "build_fcn": (
2844 build_binary_broadcast,
2845 TosaTensorGen.tgBroadcastFuzz,
2846 TosaTensorValuesGen.tvgDefault,
2847 None,
2848 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002849 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002850 "error_if_validators": (
2851 TosaErrorValidator.evRankMismatch,
2852 TosaErrorValidator.evWrongInputType,
2853 TosaErrorValidator.evWrongOutputType,
2854 TosaErrorValidator.evWrongInputList,
2855 TosaErrorValidator.evWrongOutputList,
2856 TosaErrorValidator.evDimensionMismatch,
2857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002858 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002859 "logical_xor": {
2860 "op": Op.LOGICAL_XOR,
2861 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002862 "build_fcn": (
2863 build_binary_broadcast,
2864 TosaTensorGen.tgBroadcastFuzz,
2865 TosaTensorValuesGen.tvgDefault,
2866 None,
2867 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002868 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002869 "error_if_validators": (
2870 TosaErrorValidator.evRankMismatch,
2871 TosaErrorValidator.evWrongInputType,
2872 TosaErrorValidator.evWrongOutputType,
2873 TosaErrorValidator.evWrongInputList,
2874 TosaErrorValidator.evWrongOutputList,
2875 TosaErrorValidator.evDimensionMismatch,
2876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002877 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002878 "maximum": {
2879 "op": Op.MAXIMUM,
2880 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002881 "build_fcn": (
2882 build_binary_broadcast,
2883 TosaTensorGen.tgBroadcastFuzz,
2884 TosaTensorValuesGen.tvgDefault,
2885 None,
2886 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002887 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002888 "error_if_validators": (
2889 TosaErrorValidator.evRankMismatch,
2890 TosaErrorValidator.evWrongInputType,
2891 TosaErrorValidator.evWrongOutputType,
2892 TosaErrorValidator.evWrongInputList,
2893 TosaErrorValidator.evWrongOutputList,
2894 TosaErrorValidator.evDimensionMismatch,
2895 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002896 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002897 "minimum": {
2898 "op": Op.MINIMUM,
2899 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002900 "build_fcn": (
2901 build_binary_broadcast,
2902 TosaTensorGen.tgBroadcastFuzz,
2903 TosaTensorValuesGen.tvgDefault,
2904 None,
2905 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002906 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002907 "error_if_validators": (
2908 TosaErrorValidator.evRankMismatch,
2909 TosaErrorValidator.evWrongInputType,
2910 TosaErrorValidator.evWrongOutputType,
2911 TosaErrorValidator.evWrongInputList,
2912 TosaErrorValidator.evWrongOutputList,
2913 TosaErrorValidator.evDimensionMismatch,
2914 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002915 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002916 "mul": {
2917 "op": Op.MUL,
2918 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002919 "build_fcn": (
2920 build_mul,
2921 TosaTensorGen.tgBroadcastFuzz,
2922 TosaTensorValuesGen.tvgMul,
2923 TosaArgGen.agMul,
2924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002925 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002926 "error_if_validators": (
2927 TosaErrorValidator.evWrongInputType,
2928 TosaErrorValidator.evWrongOutputType,
2929 TosaErrorValidator.evWrongInputList,
2930 TosaErrorValidator.evWrongOutputList,
2931 TosaErrorValidator.evRankMismatch,
2932 TosaErrorValidator.evDimensionMismatch,
2933 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002934 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002935 "pow": {
2936 "op": Op.POW,
2937 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002938 "build_fcn": (
2939 build_binary_broadcast,
2940 TosaTensorGen.tgBroadcastFuzz,
2941 TosaTensorValuesGen.tvgDefault,
2942 None,
2943 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002944 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002945 "error_if_validators": (
2946 TosaErrorValidator.evRankMismatch,
2947 TosaErrorValidator.evWrongInputType,
2948 TosaErrorValidator.evWrongOutputType,
2949 TosaErrorValidator.evWrongInputList,
2950 TosaErrorValidator.evWrongOutputList,
2951 TosaErrorValidator.evDimensionMismatch,
2952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002953 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002954 "sub": {
2955 "op": Op.SUB,
2956 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002957 "build_fcn": (
2958 build_binary_broadcast,
2959 TosaTensorGen.tgBroadcastFuzz,
2960 TosaTensorValuesGen.tvgAddSub,
2961 None,
2962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002964 "error_if_validators": (
2965 TosaErrorValidator.evRankMismatch,
2966 TosaErrorValidator.evWrongInputType,
2967 TosaErrorValidator.evWrongOutputType,
2968 TosaErrorValidator.evWrongInputList,
2969 TosaErrorValidator.evWrongOutputList,
2970 TosaErrorValidator.evDimensionMismatch,
2971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002973 "table": {
2974 "op": Op.TABLE,
2975 # Use the automatic generation functions to create the input array
2976 # but create the table tensor in the build function, as it may be
2977 # a different type from the input
2978 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002979 "build_fcn": (
2980 build_table,
2981 TosaTensorGen.tgBasic,
2982 TosaTensorValuesGen.tvgDefault,
2983 TosaArgGen.agTable,
2984 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002985 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002986 "error_if_validators": (
2987 TosaErrorValidator.evWrongInputType,
2988 TosaErrorValidator.evWrongOutputType,
2989 TosaErrorValidator.evWrongInputList,
2990 TosaErrorValidator.evWrongOutputList,
2991 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002992 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002993 # Elementwise Unary operators
2994 "abs": {
2995 "op": Op.ABS,
2996 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002997 "build_fcn": (
2998 build_unary,
2999 TosaTensorGen.tgBasic,
3000 TosaTensorValuesGen.tvgDefault,
3001 None,
3002 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003003 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003004 "error_if_validators": (
3005 TosaErrorValidator.evWrongInputType,
3006 TosaErrorValidator.evWrongOutputType,
3007 TosaErrorValidator.evWrongInputList,
3008 TosaErrorValidator.evWrongOutputList,
3009 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003011 "bitwise_not": {
3012 "op": Op.BITWISE_NOT,
3013 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003014 "build_fcn": (
3015 build_unary,
3016 TosaTensorGen.tgBasic,
3017 TosaTensorValuesGen.tvgDefault,
3018 None,
3019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003020 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003021 "error_if_validators": (
3022 TosaErrorValidator.evWrongInputType,
3023 TosaErrorValidator.evWrongOutputType,
3024 TosaErrorValidator.evWrongInputList,
3025 TosaErrorValidator.evWrongOutputList,
3026 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003027 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003028 "ceil": {
3029 "op": Op.CEIL,
3030 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003031 "build_fcn": (
3032 build_unary,
3033 TosaTensorGen.tgBasic,
3034 TosaTensorValuesGen.tvgDefault,
3035 None,
3036 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003037 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003038 "error_if_validators": (
3039 TosaErrorValidator.evWrongInputType,
3040 TosaErrorValidator.evWrongOutputType,
3041 TosaErrorValidator.evWrongInputList,
3042 TosaErrorValidator.evWrongOutputList,
3043 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003044 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003045 "clz": {
3046 "op": Op.CLZ,
3047 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003048 "build_fcn": (
3049 build_unary,
3050 TosaTensorGen.tgBasic,
3051 TosaTensorValuesGen.tvgDefault,
3052 None,
3053 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003054 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003055 "error_if_validators": (
3056 TosaErrorValidator.evWrongInputType,
3057 TosaErrorValidator.evWrongOutputType,
3058 TosaErrorValidator.evWrongInputList,
3059 TosaErrorValidator.evWrongOutputList,
3060 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003061 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003062 "exp": {
3063 "op": Op.EXP,
3064 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003065 "build_fcn": (
3066 build_unary,
3067 TosaTensorGen.tgBasic,
3068 TosaTensorValuesGen.tvgDefault,
3069 None,
3070 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003071 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003072 "error_if_validators": (
3073 TosaErrorValidator.evWrongInputType,
3074 TosaErrorValidator.evWrongOutputType,
3075 TosaErrorValidator.evWrongInputList,
3076 TosaErrorValidator.evWrongOutputList,
3077 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 "floor": {
3080 "op": Op.FLOOR,
3081 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003082 "build_fcn": (
3083 build_unary,
3084 TosaTensorGen.tgBasic,
3085 TosaTensorValuesGen.tvgDefault,
3086 None,
3087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003088 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003089 "error_if_validators": (
3090 TosaErrorValidator.evWrongInputType,
3091 TosaErrorValidator.evWrongOutputType,
3092 TosaErrorValidator.evWrongInputList,
3093 TosaErrorValidator.evWrongOutputList,
3094 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 "log": {
3097 "op": Op.LOG,
3098 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003099 "build_fcn": (
3100 build_unary,
3101 TosaTensorGen.tgBasic,
3102 TosaTensorValuesGen.tvgDefault,
3103 None,
3104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003106 "error_if_validators": (
3107 TosaErrorValidator.evWrongInputType,
3108 TosaErrorValidator.evWrongOutputType,
3109 TosaErrorValidator.evWrongInputList,
3110 TosaErrorValidator.evWrongOutputList,
3111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003112 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003113 "logical_not": {
3114 "op": Op.LOGICAL_NOT,
3115 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003116 "build_fcn": (
3117 build_unary,
3118 TosaTensorGen.tgBasic,
3119 TosaTensorValuesGen.tvgDefault,
3120 None,
3121 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003122 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003123 "error_if_validators": (
3124 TosaErrorValidator.evWrongInputType,
3125 TosaErrorValidator.evWrongOutputType,
3126 TosaErrorValidator.evWrongInputList,
3127 TosaErrorValidator.evWrongOutputList,
3128 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003129 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003130 "negate": {
3131 "op": Op.NEGATE,
3132 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003133 "build_fcn": (
3134 build_unary,
3135 TosaTensorGen.tgBasic,
3136 TosaTensorValuesGen.tvgNegate,
3137 None,
3138 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "qgen": TosaQuantGen.qgUnary,
3140 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 "error_if_validators": (
3142 TosaErrorValidator.evInputZeroPointNotZero,
3143 TosaErrorValidator.evOutputZeroPointNotZero,
3144 TosaErrorValidator.evWrongInputType,
3145 TosaErrorValidator.evWrongOutputType,
3146 TosaErrorValidator.evWrongInputList,
3147 TosaErrorValidator.evWrongOutputList,
3148 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003149 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 "reciprocal": {
3151 "op": Op.RECIPROCAL,
3152 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003153 "build_fcn": (
3154 build_unary,
3155 TosaTensorGen.tgBasic,
3156 TosaTensorValuesGen.tvgDefault,
3157 None,
3158 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003160 "error_if_validators": (
3161 TosaErrorValidator.evWrongInputType,
3162 TosaErrorValidator.evWrongOutputType,
3163 TosaErrorValidator.evWrongInputList,
3164 TosaErrorValidator.evWrongOutputList,
3165 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003166 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003167 "rsqrt": {
3168 "op": Op.RSQRT,
3169 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003170 "build_fcn": (
3171 build_unary,
3172 TosaTensorGen.tgBasic,
3173 TosaTensorValuesGen.tvgDefault,
3174 None,
3175 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003177 "error_if_validators": (
3178 TosaErrorValidator.evWrongInputType,
3179 TosaErrorValidator.evWrongOutputType,
3180 TosaErrorValidator.evWrongInputList,
3181 TosaErrorValidator.evWrongOutputList,
3182 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003184 # Elementwise Ternary operators
3185 "select": {
3186 "op": Op.SELECT,
3187 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 "build_fcn": (
3189 build_select,
3190 TosaTensorGen.tgBroadcastFuzz,
3191 TosaTensorValuesGen.tvgSelect,
3192 None,
3193 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003194 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003195 "error_if_validators": (
3196 TosaErrorValidator.evRankMismatch,
3197 TosaErrorValidator.evWrongInputType,
3198 TosaErrorValidator.evWrongOutputType,
3199 TosaErrorValidator.evWrongInputList,
3200 TosaErrorValidator.evWrongOutputList,
3201 TosaErrorValidator.evDimensionMismatch,
3202 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003204 # Comparison operators
3205 "equal": {
3206 "op": Op.EQUAL,
3207 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003208 "build_fcn": (
3209 build_comparison,
3210 TosaTensorGen.tgBroadcastFuzz,
3211 TosaTensorValuesGen.tvgEqual,
3212 None,
3213 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003214 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003215 "error_if_validators": (
3216 TosaErrorValidator.evRankMismatch,
3217 TosaErrorValidator.evWrongInputType,
3218 TosaErrorValidator.evWrongOutputType,
3219 TosaErrorValidator.evWrongInputList,
3220 TosaErrorValidator.evWrongOutputList,
3221 TosaErrorValidator.evDimensionMismatch,
3222 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 "greater_equal": {
3225 "op": Op.GREATER_EQUAL,
3226 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003227 "build_fcn": (
3228 build_comparison,
3229 TosaTensorGen.tgBroadcastFuzz,
3230 TosaTensorValuesGen.tvgDefault,
3231 None,
3232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003233 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003234 "error_if_validators": (
3235 TosaErrorValidator.evRankMismatch,
3236 TosaErrorValidator.evWrongInputType,
3237 TosaErrorValidator.evWrongOutputType,
3238 TosaErrorValidator.evWrongInputList,
3239 TosaErrorValidator.evWrongOutputList,
3240 TosaErrorValidator.evDimensionMismatch,
3241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003242 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 "greater": {
3244 "op": Op.GREATER,
3245 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003246 "build_fcn": (
3247 build_comparison,
3248 TosaTensorGen.tgBroadcastFuzz,
3249 TosaTensorValuesGen.tvgDefault,
3250 None,
3251 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003252 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003253 "error_if_validators": (
3254 TosaErrorValidator.evRankMismatch,
3255 TosaErrorValidator.evWrongInputType,
3256 TosaErrorValidator.evWrongOutputType,
3257 TosaErrorValidator.evWrongInputList,
3258 TosaErrorValidator.evWrongOutputList,
3259 TosaErrorValidator.evDimensionMismatch,
3260 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003261 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 # Reduction operators
3263 "reduce_all": {
3264 "op": Op.REDUCE_ALL,
3265 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003266 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003267 "build_fcn": (
3268 build_reduce,
3269 TosaTensorGen.tgBasic,
3270 TosaTensorValuesGen.tvgDefault,
3271 TosaArgGen.agAxis,
3272 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003273 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003274 "error_if_validators": (
3275 TosaErrorValidator.evAxisLargerRank,
3276 TosaErrorValidator.evAxisSmallerZero,
3277 TosaErrorValidator.evShapeOfAxisNotOne,
3278 TosaErrorValidator.evWrongInputType,
3279 TosaErrorValidator.evWrongOutputType,
3280 TosaErrorValidator.evWrongRank,
3281 TosaErrorValidator.evWrongInputList,
3282 TosaErrorValidator.evWrongOutputList,
3283 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003284 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003285 "reduce_any": {
3286 "op": Op.REDUCE_ANY,
3287 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003288 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 "build_fcn": (
3290 build_reduce,
3291 TosaTensorGen.tgBasic,
3292 TosaTensorValuesGen.tvgDefault,
3293 TosaArgGen.agAxis,
3294 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003295 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003296 "error_if_validators": (
3297 TosaErrorValidator.evAxisLargerRank,
3298 TosaErrorValidator.evAxisSmallerZero,
3299 TosaErrorValidator.evShapeOfAxisNotOne,
3300 TosaErrorValidator.evWrongInputType,
3301 TosaErrorValidator.evWrongOutputType,
3302 TosaErrorValidator.evWrongRank,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003306 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003307 "reduce_max": {
3308 "op": Op.REDUCE_MAX,
3309 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003310 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003311 "build_fcn": (
3312 build_reduce,
3313 TosaTensorGen.tgBasic,
3314 TosaTensorValuesGen.tvgDefault,
3315 TosaArgGen.agAxis,
3316 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003317 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003318 "error_if_validators": (
3319 TosaErrorValidator.evAxisLargerRank,
3320 TosaErrorValidator.evAxisSmallerZero,
3321 TosaErrorValidator.evShapeOfAxisNotOne,
3322 TosaErrorValidator.evWrongInputType,
3323 TosaErrorValidator.evWrongOutputType,
3324 TosaErrorValidator.evWrongRank,
3325 TosaErrorValidator.evWrongInputList,
3326 TosaErrorValidator.evWrongOutputList,
3327 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003330 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003331 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003332 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003333 "build_fcn": (
3334 build_reduce,
3335 TosaTensorGen.tgBasic,
3336 TosaTensorValuesGen.tvgDefault,
3337 TosaArgGen.agAxis,
3338 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003339 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003340 "error_if_validators": (
3341 TosaErrorValidator.evAxisLargerRank,
3342 TosaErrorValidator.evAxisSmallerZero,
3343 TosaErrorValidator.evShapeOfAxisNotOne,
3344 TosaErrorValidator.evWrongInputType,
3345 TosaErrorValidator.evWrongOutputType,
3346 TosaErrorValidator.evWrongRank,
3347 TosaErrorValidator.evWrongInputList,
3348 TosaErrorValidator.evWrongOutputList,
3349 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003351 "reduce_product": {
3352 "op": Op.REDUCE_PRODUCT,
3353 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003354 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003355 "build_fcn": (
3356 build_reduce,
3357 TosaTensorGen.tgBasic,
3358 TosaTensorValuesGen.tvgDefault,
3359 TosaArgGen.agAxis,
3360 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003361 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003362 "error_if_validators": (
3363 TosaErrorValidator.evAxisLargerRank,
3364 TosaErrorValidator.evAxisSmallerZero,
3365 TosaErrorValidator.evShapeOfAxisNotOne,
3366 TosaErrorValidator.evWrongInputType,
3367 TosaErrorValidator.evWrongOutputType,
3368 TosaErrorValidator.evWrongRank,
3369 TosaErrorValidator.evWrongInputList,
3370 TosaErrorValidator.evWrongOutputList,
3371 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 "reduce_sum": {
3374 "op": Op.REDUCE_SUM,
3375 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003376 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003377 "build_fcn": (
3378 build_reduce,
3379 TosaTensorGen.tgBasic,
3380 TosaTensorValuesGen.tvgReduceSum,
3381 TosaArgGen.agAxis,
3382 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003384 "error_if_validators": (
3385 TosaErrorValidator.evAxisLargerRank,
3386 TosaErrorValidator.evAxisSmallerZero,
3387 TosaErrorValidator.evShapeOfAxisNotOne,
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongRank,
3391 TosaErrorValidator.evWrongInputList,
3392 TosaErrorValidator.evWrongOutputList,
3393 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003395 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003396 "concat": {
3397 "op": Op.CONCAT,
3398 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 "build_fcn": (
3400 build_concat,
3401 TosaTensorGen.tgConcat,
3402 TosaTensorValuesGen.tvgConcat,
3403 TosaArgGen.agAxis,
3404 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003405 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003406 "error_if_validators": (
3407 TosaErrorValidator.evAxisLargerRank,
3408 TosaErrorValidator.evAxisSmallerZero,
3409 TosaErrorValidator.evConcatInputRankMismatch,
3410 TosaErrorValidator.evConcatShapeSumMismatch,
3411 TosaErrorValidator.evConcatInputDimMismatch,
3412 TosaErrorValidator.evWrongInputType,
3413 TosaErrorValidator.evWrongOutputType,
3414 TosaErrorValidator.evWrongOutputList,
3415 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003416 },
3417 "pad": {
3418 "op": Op.PAD,
3419 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003420 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003421 "build_fcn": (
3422 build_pad,
3423 TosaTensorGen.tgBasic,
3424 TosaTensorValuesGen.tvgDefault,
3425 TosaArgGen.agPad,
3426 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003427 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003428 "error_if_validators": (
3429 TosaErrorValidator.evWrongInputType,
3430 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003431 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003432 TosaErrorValidator.evWrongOutputType,
3433 TosaErrorValidator.evWrongInputList,
3434 TosaErrorValidator.evWrongOutputList,
3435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003436 },
3437 "reshape": {
3438 "op": Op.RESHAPE,
3439 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003440 "build_fcn": (
3441 build_reshape,
3442 TosaTensorGen.tgBasic,
3443 TosaTensorValuesGen.tvgDefault,
3444 TosaArgGen.agReshape,
3445 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003446 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 "error_if_validators": (
3448 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3449 TosaErrorValidator.evWrongInputType,
3450 TosaErrorValidator.evWrongOutputType,
3451 TosaErrorValidator.evWrongInputList,
3452 TosaErrorValidator.evWrongOutputList,
3453 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003454 },
3455 "reverse": {
3456 "op": Op.REVERSE,
3457 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003458 "build_fcn": (
3459 build_reverse,
3460 TosaTensorGen.tgBasic,
3461 TosaTensorValuesGen.tvgDefault,
3462 TosaArgGen.agAxis,
3463 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003464 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003465 "error_if_validators": (
3466 TosaErrorValidator.evAxisSmallerZero,
3467 TosaErrorValidator.evAxisLargerRank,
3468 TosaErrorValidator.evWrongInputType,
3469 TosaErrorValidator.evWrongOutputType,
3470 TosaErrorValidator.evWrongInputList,
3471 TosaErrorValidator.evWrongOutputList,
3472 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003473 },
3474 "slice": {
3475 "op": Op.SLICE,
3476 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003477 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003478 "build_fcn": (
3479 build_slice,
3480 TosaTensorGen.tgBasic,
3481 TosaTensorValuesGen.tvgDefault,
3482 TosaArgGen.agSlice,
3483 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003484 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003485 "error_if_validators": (
3486 TosaErrorValidator.evStartSmallerZero,
3487 TosaErrorValidator.evSizeSmallerEqualZero,
3488 TosaErrorValidator.evStartSizeOutsideBounds,
3489 TosaErrorValidator.evSizeOutputShapeMismatch,
3490 TosaErrorValidator.evInputSizeStartLengthMismatch,
3491 TosaErrorValidator.evWrongRank,
3492 TosaErrorValidator.evWrongInputType,
3493 TosaErrorValidator.evWrongOutputType,
3494 TosaErrorValidator.evWrongInputList,
3495 TosaErrorValidator.evWrongOutputList,
3496 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003497 },
3498 "tile": {
3499 "op": Op.TILE,
3500 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003501 "build_fcn": (
3502 build_tile,
3503 TosaTensorGen.tgBasic,
3504 TosaTensorValuesGen.tvgDefault,
3505 TosaArgGen.agTile,
3506 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003507 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003508 "error_if_validators": (
3509 TosaErrorValidator.evWrongInputType,
3510 TosaErrorValidator.evWrongOutputType,
3511 TosaErrorValidator.evWrongInputList,
3512 TosaErrorValidator.evWrongOutputList,
3513 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003514 },
3515 "transpose": {
3516 "op": Op.TRANSPOSE,
3517 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003518 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003519 "build_fcn": (
3520 build_transpose,
3521 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003522 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003523 TosaArgGen.agTranspose,
3524 ),
3525 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evIndexOutsideBounds,
3528 TosaErrorValidator.evIndexUsedTwice,
3529 TosaErrorValidator.evWrongInputType,
3530 TosaErrorValidator.evWrongOutputType,
3531 TosaErrorValidator.evWrongInputList,
3532 TosaErrorValidator.evWrongOutputList,
3533 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003534 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 # Data nodes
3536 "const": {
3537 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003538 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 "build_fcn": (
3540 build_const,
3541 TosaTensorGen.tgBasic,
3542 TosaTensorValuesGen.tvgDefault,
3543 None,
3544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "types": TYPE_FIB,
3546 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "identity": {
3548 "op": Op.IDENTITY,
3549 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003550 "build_fcn": (
3551 build_unary,
3552 TosaTensorGen.tgBasic,
3553 TosaTensorValuesGen.tvgDefault,
3554 None,
3555 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003556 "types": TYPE_FIB,
3557 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003558 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003559 "gather": {
3560 "op": Op.GATHER,
3561 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3562 "operands": (1, 0),
3563 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003564 "build_fcn": (
3565 build_gather,
3566 TosaTensorGen.tgBasic,
3567 TosaTensorValuesGen.tvgDefault,
3568 None,
3569 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003570 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003571 "error_if_validators": (
3572 TosaErrorValidator.evWrongInputType,
3573 TosaErrorValidator.evWrongOutputType,
3574 TosaErrorValidator.evWrongInputList,
3575 TosaErrorValidator.evWrongOutputList,
3576 TosaErrorValidator.evWrongRank,
3577 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003578 },
3579 "scatter": {
3580 "op": Op.SCATTER,
3581 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003583 "operands": (2, 0),
3584 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003585 "build_fcn": (
3586 build_scatter,
3587 TosaTensorGen.tgScatter,
3588 TosaTensorValuesGen.tvgDefault,
3589 None,
3590 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003591 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003592 "error_if_validators": (
3593 TosaErrorValidator.evWrongInputType,
3594 TosaErrorValidator.evWrongOutputType,
3595 TosaErrorValidator.evWrongInputList,
3596 TosaErrorValidator.evWrongOutputList,
3597 TosaErrorValidator.evWrongRank,
3598 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003599 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003600 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003601 "resize": {
3602 "op": Op.RESIZE,
3603 "operands": (1, 0),
3604 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003605 "build_fcn": (
3606 build_resize,
3607 TosaTensorGen.tgNHWC,
3608 TosaTensorValuesGen.tvgDefault,
3609 TosaArgGen.agResize,
3610 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003611 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003612 "invalid_test_validators": (
3613 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 ),
3615 "error_if_validators": (
3616 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003617 TosaErrorValidator.evScaleSmallerEqualZero,
3618 TosaErrorValidator.evScaleNLargerMax,
3619 TosaErrorValidator.evScaleDLargerMax,
3620 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003621 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003622 TosaErrorValidator.evBorderSmallerMin,
3623 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 TosaErrorValidator.evWrongInputType,
3625 TosaErrorValidator.evWrongOutputType,
3626 TosaErrorValidator.evWrongRank,
3627 TosaErrorValidator.evWrongInputList,
3628 TosaErrorValidator.evWrongOutputList,
3629 TosaErrorValidator.evBatchMismatch,
3630 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003631 TosaErrorValidator.evResizeOutputShapeMismatch,
3632 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003633 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003635 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003636 "cast": {
3637 "op": Op.CAST,
3638 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 "build_fcn": (
3640 build_cast,
3641 TosaTensorGen.tgBasic,
3642 TosaTensorValuesGen.tvgDefault,
3643 TosaArgGen.agCast,
3644 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003645 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003646 "error_if_validators": (
3647 TosaErrorValidator.evWrongInputType,
3648 TosaErrorValidator.evWrongOutputType,
3649 TosaErrorValidator.evWrongInputList,
3650 TosaErrorValidator.evWrongOutputList,
3651 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003652 },
3653 "rescale": {
3654 "op": Op.RESCALE,
3655 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003656 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 "build_fcn": (
3658 build_rescale,
3659 TosaTensorGen.tgBasic,
3660 TosaTensorValuesGen.tvgDefault,
3661 TosaArgGen.agRescale,
3662 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003663 "types": [
3664 DType.UINT8,
3665 DType.INT8,
3666 DType.INT16,
3667 DType.INT32,
3668 DType.INT48,
3669 DType.UINT16,
3670 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003671 "error_if_validators": (
3672 TosaErrorValidator.evInputZeroPointNotZero,
3673 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003674 TosaErrorValidator.evU16InputZeroPointNotValid,
3675 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003676 TosaErrorValidator.evScaleTrue,
3677 TosaErrorValidator.evScaleNotTrue,
3678 TosaErrorValidator.evWrongInputType,
3679 TosaErrorValidator.evWrongOutputType,
3680 TosaErrorValidator.evWrongRank,
3681 TosaErrorValidator.evWrongInputList,
3682 TosaErrorValidator.evWrongOutputList,
3683 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003684 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003685 # Custom
3686 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003688 # Two varients of cond_if, one that generates one of two constant tensors (no
3689 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3690 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003691 "cond_if_const": {
3692 "op": Op.COND_IF,
3693 "operands": (0, 2),
3694 "build_fcn": (
3695 build_cond_if_const,
3696 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003698 TosaArgGen.agCondIf,
3699 ),
3700 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003701 "error_if_validators": (
3702 TosaErrorValidator.evOutputListThenGraphMismatch,
3703 TosaErrorValidator.evOutputListElseGraphMismatch,
3704 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003705 },
3706 "cond_if_binary": {
3707 "op": Op.COND_IF,
3708 "operands": (2, 0),
3709 "build_fcn": (
3710 build_cond_if_binary,
3711 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003713 TosaArgGen.agCondIf,
3714 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003715 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003716 "error_if_validators": (
3717 TosaErrorValidator.evInputListThenGraphMismatch,
3718 TosaErrorValidator.evInputListElseGraphMismatch,
3719 TosaErrorValidator.evOutputListThenGraphMismatch,
3720 TosaErrorValidator.evOutputListElseGraphMismatch,
3721 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003722 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003723 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003724 "while_loop": {
3725 "op": Op.WHILE_LOOP,
3726 "operands": (0, 1),
3727 "build_fcn": (
3728 build_while_loop,
3729 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003730 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003731 TosaArgGen.agWhileLoop,
3732 ),
3733 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evInputListOutputListMismatch,
3736 TosaErrorValidator.evInputListCondGraphMismatch,
3737 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3738 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3739 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003741 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003742 }
3743
Kevin Cheng550ccc52021-03-03 11:21:43 -08003744
Eric Kunzee5e26762020-10-13 16:11:07 -07003745class OutputShaper:
3746 # Methods in this class compute the expected output shape and datatype
3747 # for common classes of operations
3748 def __init__(self):
3749 pass
3750
3751 # These methods return arguments that can be used for
3752 # creating a new output tensor
3753 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output of an element-wise binary op that may broadcast.

    Args:
        ser: serializer; its addOutput(shape, dtype) receives the result.
        rng: random generator, only consulted for the WrongOutputType case.
        a, b: input tensors (objects exposing .shape and .dtype).
        error_name: ErrorIf case being generated, or None for a valid test.

    Returns:
        Whatever ser.addOutput returns.
    """
    # Ranks may only differ when a RankMismatch test is being generated.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Size-1 dims of `a` take b's extent only for valid tests; ERROR_IF tests
    # keep a's shape and never index into b.shape.
    valid_test = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and valid_test) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    # Draw any listed dtype other than the input's to provoke a type error.
    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(out_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003780
3781 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output of a binary op whose inputs must match exactly.

    Asserts that a and b agree in rank, dtype and every dimension, then
    passes a new list copy of a's shape, with a's dtype, to ser.addOutput.

    Returns:
        Whatever ser.addOutput returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003792
3793 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output of an element-wise unary op.

    The output reuses a.shape. Its dtype is a.dtype, except in the
    WrongOutputType ERROR_IF case, where a different dtype is drawn with rng.

    Returns:
        Whatever ser.addOutput returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(a.shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003809
3810 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output of SELECT(cond, a, b).

    The output has cond's rank. For a valid test (error_name is None), a
    size-1 cond dimension is widened to the largest extent among cond, a
    and b; otherwise cond's extent is kept. The dtype follows a.dtype unless
    the WrongOutputType case asks for a different one.

    Returns:
        Whatever ser.addOutput returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    valid_test = error_name is None
    out_shape = [
        max(dim, a.shape[idx], b.shape[idx]) if (dim == 1 and valid_test) else dim
        for idx, dim in enumerate(cond.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003837
3838 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the boolean output of an element-wise comparison.

    The output keeps a's rank. A size-1 dimension of `a` takes b's extent
    whenever `b` has that dimension. The dtype is DType.BOOL, or a random
    non-BOOL dtype for the WrongOutputType case.

    Returns:
        Whatever ser.addOutput returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the rank guard avoids indexing past the end of b.shape,
    # which can be shorter when the rank assert above is skipped.
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and b_rank > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    out_dtype = (
        rng.choice(non_bool_dtypes)
        if error_name == ErrorIf.WrongOutputType
        else DType.BOOL
    )
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003865
3866 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output of a reduction along `axis`.

    Normally the reduced axis becomes 1. For the AxisSmallerZero and
    AxisLargerRank cases the input shape is left untouched. For
    ShapeOfAxisNotOne, an axis extent of 1 is replaced by a random value
    in [2, 10). The dtype is a.dtype unless WrongOutputType picks another.

    Returns:
        Whatever ser.addOutput returns.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # Make sure the reduced axis is *not* 1 so the error is triggered.
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        # The axis-error cases presumably carry an out-of-range axis, so it
        # is not used as an index there.
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003892
3893 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the output of ARGMAX along `axis`.

    The reduced axis is removed from the shape, except for the two
    axis-out-of-range cases. Then:
      * ArgmaxOutputRankMismatch drops the leading dim or appends a 1.
      * ArgmaxOutputShapeMismatch grows each remaining dim by a random 1..9.
    The dtype is INT32 unless WrongOutputType picks another.

    Returns:
        Whatever ser.addOutput returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {DType.INT32}))
    return ser.addOutput(out_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003924
3925 @staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Register the output of CONV2D.

    Layouts: ifm NHWC, filter OHWI, output NHWC. Each spatial extent is
    (in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1.
    ConvOutputShapeMismatch enlarges H and/or W by a random multiple of the
    stride. The output dtype is INT32 / INT48 / FLOAT for INT8 / INT16 /
    FLOAT inputs; WrongOutputType swaps it for another usable dtype.

    Raises:
        Exception: unsupported input dtype outside the WrongInputType case.
    """

    def out_extent(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer error
        # case is not hit by accident.
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # The input dtype is deliberately wrong; any plausible output will do.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003976
3977 @staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Register the output of CONV3D.

    Layouts: ifm NDHWC, filter ODHWI, output NDHWC. Spatial axis i (D, H,
    W) uses ifm.shape[1 + i], padding[2i] / padding[2i + 1],
    filter.shape[1 + i], dilations[i] and strides[i] in
    (in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1.
    ConvOutputShapeMismatch enlarges D, H, W or all three by a random
    multiple of the matching stride. The output dtype is INT32 / INT48 /
    FLOAT for INT8 / INT16 / FLOAT inputs; WrongOutputType swaps it for
    another usable dtype.

    Raises:
        Exception: unsupported input dtype outside the WrongInputType case.
    """
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # 1 -> D, 2 -> H, 3 -> W, 4 -> all three. Growth is in whole strides
        # so the non-integer error case is not hit by accident.
        for i, selector in enumerate((1, 2, 3)):
            if which in (selector, 4):
                spatial[i] += rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], spatial[2], filter.shape[0]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # The input dtype is deliberately wrong; any plausible output will do.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
4038
4039 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, strides, padding, dilations, error_name=None
):
    """Register the output of DEPTHWISE_CONV2D.

    Layouts: ifm NHWC, filter HWCM, output NHW(C*M). Each spatial extent is
    (in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1,
    with k taken from filter.shape[0] (H) and filter.shape[1] (W).
    ConvOutputShapeMismatch enlarges H and/or W by a random multiple of the
    stride. The output dtype is INT32 / INT48 / FLOAT for INT8 / INT16 /
    FLOAT inputs; WrongOutputType swaps it for another usable dtype.

    Raises:
        Exception: unsupported input dtype outside the WrongInputType case.
    """

    def out_extent(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer error
        # case is not hit by accident.
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # The input dtype is deliberately wrong; any plausible output will do.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004091
4092 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Register the output of a 2D pooling op on an NHWC input.

    With a non-positive stride or negative padding the spatial size is a
    1x1 placeholder, since such a test is invalid anyway. Otherwise
    out = (in + pad_before + pad_after - kernel) // stride + 1.
    PoolingOutputShapeMismatch grows H and/or W by a random multiple of the
    stride. WrongOutputType swaps ifm.dtype for another dtype.

    Returns:
        Whatever ser.addOutput returns.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer error
        # case is not hit by accident.
        if which in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004127
4128 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Register the output of FULLY_CONNECTED.

    Shapes: input (N, IC), filter (OC, IC), output (N, OC). The output dtype
    is INT32 / INT48 / FLOAT for INT8 / INT16 / FLOAT inputs.
      * WrongOutputType: a dtype that is not the correct output is drawn
        with rng.
      * WrongInputType: an unrecognised input dtype gets INT32.

    Returns:
        Whatever ser.addOutput returns.

    Raises:
        Exception: if the input dtype is not INT8, INT16 or FLOAT, unless the
            case is WrongInputType.
    """
    output_shape = [input.shape[0], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists dtypes that are NOT the correct output for the input.
        if input.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif input.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif input.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch incorrect_types would be unbound below and
            # the failure would surface as an UnboundLocalError.
            raise Exception("Unsupported input dtype: {}".format(input.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004175
4176 @staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Register the output of MATMUL.

    Shapes: a (N, H, C), b (N, C, W), output (N, H, W). The output dtype is
    INT32 / INT48 / FLOAT for INT8 / INT16 / FLOAT inputs.
      * WrongOutputType: a dtype that is not the correct output is drawn
        with rng.
      * WrongInputType: an unrecognised input dtype gets INT32.

    Returns:
        Whatever ser.addOutput returns.

    Raises:
        Exception: if a.dtype is not INT8, INT16 or FLOAT, unless the case
            is WrongInputType.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists dtypes that are NOT the correct output for the input.
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif a.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch incorrect_types would be unbound below and
            # the failure would surface as an UnboundLocalError.
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004223
4224 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Register the output of CONCAT of the tensors in `a` along `axis`.

    The output starts as a copy of the first input's shape. The other
    inputs' extents along `axis` are added to it, except for
    ConcatInputRankMismatch or an out-of-range axis, where no meaningful sum
    exists. ConcatShapeSumMismatch then inflates the axis by a random 5..9.
    The dtype is the first input's unless WrongOutputType picks another.

    Returns:
        Whatever ser.addOutput returns.
    """
    first = a[0]
    others = a[1:]

    out_shape = first.shape.copy()
    skip_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not skip_sum:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004257
4258 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register the output of PAD.

    Each output extent is padding[i][0] + padding[i][1] + a.shape[i].
    ERROR_IF adjustments:
      * PadOutputShapeMismatch: one random dim is shrunk by 1 or 2, or grown
        by the same amount when shrinking would leave an extent below 1.
      * PadSmallerZero: extents below 1 are clamped to 1.
      * WrongOutputType: a dtype other than a.dtype is drawn with rng.

    Returns:
        Whatever ser.addOutput returns.
    """
    output_shape = a.shape.copy()

    for i in range(len(output_shape)):
        output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        delta = rng.choice([1, 2])
        if output_shape[bad_dim] - delta >= 1:
            output_shape[bad_dim] -= delta
        else:
            # Shrinking would produce a zero/negative extent, the same kind of
            # shape the PadSmallerZero fix-up below avoids; grow instead so
            # the shape is still mismatched but remains positive.
            output_shape[bad_dim] += delta

    # Fix negative output shape if error_if test causes it
    if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
        output_shape = [i if i >= 1 else 1 for i in output_shape]

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004288
4289 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register the output of RESHAPE to `shape`.

    The output is a copy of the requested `shape`.
      * TensorSizeInputOutputMismatch: every dim grows by a random 1..9.
      * WrongOutputType: a.dtype is swapped for another dtype.

    Returns:
        Whatever ser.addOutput returns.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(new_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004311
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the SLICE output tensor and register it with ``ser``.

    The output shape is a copy of ``size`` (``start`` does not affect it).
    For ``ErrorIf.SizeOutputShapeMismatch`` every dimension is perturbed:
    extents <= 2 only grow (by 1 or 2) so they stay positive, larger extents
    move by one of -2, -1, +1, +2. For ``ErrorIf.WrongOutputType`` the dtype
    is drawn from a fixed list excluding ``a.dtype``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        outputDType = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        outputDType = a.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(output_shape):
            # Small extents may only grow, otherwise they could hit zero.
            offsets = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(offsets)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004341
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the TILE output tensor and register it with ``ser``.

    Each output dimension is ``a.shape[i] * multiples[i]``; ``multiples``
    must have one entry per input dimension (asserted). For
    ``ErrorIf.WrongOutputType`` the dtype is drawn from a fixed list
    excluding ``a.dtype``; otherwise the input dtype is kept.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for axis, extent in enumerate(a.shape):
        output_shape[axis] = extent * multiples[axis]

    if error_name != ErrorIf.WrongOutputType:
        outputDType = a.dtype
    else:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        outputDType = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004365
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the TRANSPOSE output tensor and register it with ``ser``.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For
    ``ErrorIf.IndexOutsideBounds`` ``perms`` is not consulted at all (it may
    hold out-of-range indices); every dimension becomes ``a.shape[0]``
    instead. For ``ErrorIf.WrongOutputType`` the dtype is drawn from a fixed
    list excluding ``a.dtype``.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    ignore_perms = error_name == ErrorIf.IndexOutsideBounds
    for axis in range(len(output_shape)):
        source_axis = 0 if ignore_perms else perms[axis]
        output_shape[axis] = a.shape[source_axis]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        outputDType = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004393
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the GATHER output tensor and register it with ``ser``.

    Unless testing ``ErrorIf.WrongRank``, ``values`` must be rank 3 and
    ``indices`` rank 2 with matching first dimensions (asserted). The output
    shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``. For
    ``ErrorIf.WrongOutputType`` the dtype is drawn from a fixed list
    excluding ``values.dtype``; otherwise ``values.dtype`` is kept.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    lead = values.shape[0]
    gathered = indices.shape[1]
    trail = values.shape[2]
    output_shape = [lead, gathered, trail]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    )
    wrong_dtypes = list(set(candidate_dtypes) - {values.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004417
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the SCATTER output tensor and register it with ``ser``.

    Unless testing ``ErrorIf.WrongRank``, the ranks (3, 2, 3) and the N/W/C
    agreements between ``values_in``, ``indices`` and ``input`` are asserted.
    The output uses ``values_in.shape`` directly (the same list object, not a
    copy). For ``ErrorIf.WrongOutputType`` the dtype is drawn from a fixed
    list excluding ``values_in.dtype``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    output_shape = values_in.shape

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        chosen = rng.choice(list(set(candidate_dtypes) - {values_in.dtype}))
    else:
        chosen = values_in.dtype

    return ser.addOutput(output_shape, chosen)
Eric Kunzee5e26762020-10-13 16:11:07 -07004444
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the TABLE output tensor and register it with ``ser``.

    The output has the same shape as ``input``. Its dtype is INT32 for an
    INT16 input and INT8 otherwise; the input must be INT8 or INT16 unless
    testing ``ErrorIf.WrongInputType``. For ``ErrorIf.WrongOutputType`` a
    different dtype is drawn from a fixed list.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT8
    if input.dtype == DType.INT16:
        output_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            dtype
            for dtype in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            )
            if dtype != output_dtype
        ]
        output_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004462
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the RESIZE output tensor and register it with ``serializer``.

    ``scale`` is read as ``[y_n, y_d, x_n, x_d]``, ``offset`` and ``border``
    as ``[y, x]``. Height and width are computed as
    ``((in - 1) * n - offset + border) // d + 1`` from ``input.shape[1]`` and
    ``input.shape[2]``. ``mode`` and ``input_dtype`` are not used here.

    For any ERROR_IF test the computed extents are clamped to at least 1 and,
    except for ``ErrorIf.MaxDimExceeded``, below ``MAX_RESIZE_DIMENSION``.
    Specific error names then perturb the height/width, batch, channel or
    rank of the returned shape.
    """
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the shape arithmetic below free of zero/negative divisors.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(value, 1) for value in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Altered scale/offset/border values may have produced an unusable
        # output size; pull it back into range.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            upper = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, upper), min(ow, upper)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def nudge(extent, step):
            # Move by a whole denominator so the non-integer error case is
            # not triggered; step down instead when stepping up would reach
            # the dimension limit.
            if extent + step >= MAX_RESIZE_DIMENSION:
                extent -= step
                assert extent > 0  # Should have been caught in agResize
                return extent
            return extent + step

        change = rng.choice([1, 2, 3])
        if change in (1, 3):
            oh = nudge(oh, scale_y_d)
        if change in (2, 3):
            ow = nudge(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004541
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output with ``val``'s shape and the requested ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for signature uniformity with the
    other shape helpers but are not used.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004545
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Create the TRANSPOSE_CONV2D output tensor and register it with ``ser``.

    The output dtype follows the input: INT8 -> INT32, INT16 -> INT48,
    FLOAT -> FLOAT. Any other input dtype raises unless testing
    ``ErrorIf.WrongInputType``, in which case INT32 is used.

    Note: for ``ErrorIf.ConvOutputShapeMismatch`` index 1 and/or 2 of
    ``output_shape`` are increased IN PLACE, so the caller's list is modified.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        selector = rng.choice(amounts)
        # selector 1 -> grow index 1, 2 -> grow index 2, 3 -> grow both.
        for axis, triggers in ((1, (1, 3)), (2, (2, 3))):
            if selector in triggers:
                output_shape[axis] = output_shape[axis] + rng.choice(amounts)

    out_dtype_for_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    out_dtype = out_dtype_for_input.get(ifm.dtype)
    if out_dtype is None:
        if error_name != ErrorIf.WrongInputType:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(output_shape, out_dtype)