blob: fe05b57f2e6703592a54cd1f620ee0928924f348 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010016from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010017from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000018from tosa.DType import DType
19from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010020
21
Eric Kunzee5e26762020-10-13 16:11:07 -070022class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010023 # Maximum rank of tensor supported by test generator.
24 TOSA_TENSOR_MAX_RANK = 6
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026 def __init__(self, args):
27 self.args = args
28 self.basePath = args.output_dir
29 self.random_seed = args.random_seed
30 self.ser = None
31 self.rng = np.random.default_rng(self.random_seed)
32 self.createDynamicOpLists()
33 self.initOpListDefaults()
34 self.quantGen = TosaQuantGen()
35 # Force makeShape to do a specific starting shape
36 self.targetted_shape = None
37
38 def createSerializer(self, opName, testPath):
39 self.testPath = os.path.join(opName, testPath)
40
41 fullPath = os.path.join(self.basePath, self.testPath)
42 os.makedirs(fullPath, exist_ok=True)
43 self.ser = ts.TosaSerializer(fullPath)
44
    def getSerializer(self):
        """Return the serializer from the most recent createSerializer() call.

        self.ser is None until createSerializer() has been called.
        """
        return self.ser
47
48 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080049 with open(
50 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
51 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070052 fd.write(self.ser.serialize())
53
Kevin Cheng550ccc52021-03-03 11:21:43 -080054 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
55 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070056
Matthew Haddon74567092021-07-16 15:38:20 +010057 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000058 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010059 seed = self.random_seed + 1
60 self.rng = np.random.default_rng(seed)
61
Eric Kunzee5e26762020-10-13 16:11:07 -070062 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070063 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070064 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070065 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070066 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070067 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070068 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010069 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
70 elif dtype == DType.UINT8:
71 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070072 elif dtype == DType.INT16:
73 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010074 elif dtype == DType.UINT16:
75 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070076 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080077 return np.int32(
78 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
79 )
Eric Kunzee5e26762020-10-13 16:11:07 -070080 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080081 return np.int64(
82 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
83 )
Eric Kunzee5e26762020-10-13 16:11:07 -070084 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +010085 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070086 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -080087 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -070088
Kevin Cheng989cb052021-04-28 16:29:44 -070089 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -070090 placeholders = []
91
Kevin Cheng989cb052021-04-28 16:29:44 -070092 assert len(shape_list) == len(dtype_list)
93
94 for idx, shape in enumerate(shape_list):
95 arr = self.getRandTensor(shape, dtype_list[idx])
96 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -070097
98 return placeholders
99
Kevin Cheng989cb052021-04-28 16:29:44 -0700100 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700101 consts = []
102
Kevin Cheng989cb052021-04-28 16:29:44 -0700103 assert len(shape_list) == len(dtype_list)
104
105 for idx, shape in enumerate(shape_list):
106 arr = self.getRandTensor(shape, dtype_list[idx])
107 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700108
109 return consts
110
111 def makeShape(self, rank):
112 if self.targetted_shape:
113 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800114 return np.int32(
115 self.rng.integers(
116 low=self.args.tensor_shape_range[0],
117 high=self.args.tensor_shape_range[1],
118 size=rank,
119 )
120 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
    def setTargetShape(self, shape):
        """Store a shape for makeShape() to return instead of a random one.

        makeShape() only uses it while it is truthy, so passing None (or an
        empty sequence) restores random shapes.
        """
        self.targetted_shape = shape
124
125 def randInt(self, low=0, high=256):
126 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
127
128 def getRandNumberDType(self, dtype):
129 if dtype == DType.FLOAT:
130 return self.rng.random()
131 elif dtype == DType.BOOL:
132 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700133 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700134 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700135 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100137 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700138 elif dtype == DType.INT16:
139 low, high = (-32768, 32768)
140 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800141 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700142 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800143 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 # Special size
145 return np.int64(self.rng.integers(low, high, size=1))[0]
146 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800147 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700148
149 return np.int32(self.rng.integers(low, high, size=1))[0]
150
151 def shapeStr(self, shape):
152
153 sStr = []
154 # Convert to strings
155 for i in shape:
156 sStr.append(str(i))
157
Kevin Cheng550ccc52021-03-03 11:21:43 -0800158 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700159
160 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -0700161 if isinstance(t, list):
162 assert len(t) >= 2
163 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 else:
Kevin Cheng989cb052021-04-28 16:29:44 -0700165 if t == DType.BOOL:
166 return "b"
167 elif t == DType.INT4:
168 return "i4"
169 elif t == DType.INT8:
170 return "i8"
171 elif t == DType.UINT8:
172 return "u8"
173 elif t == DType.INT16:
174 return "i16"
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100175 elif t == DType.UINT16:
176 return "u16"
Kevin Cheng989cb052021-04-28 16:29:44 -0700177 elif t == DType.INT32:
178 return "i32"
179 elif t == DType.INT48:
180 return "i48"
181 elif t == DType.FLOAT:
182 return "float"
183 else:
184 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700185
186 def typeWidth(self, t):
Jeremy Johnson5d1a3472022-03-31 09:50:06 +0100187 """Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -0800188 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 return 4
190 elif t == DType.INT8:
191 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -0800192 elif t == DType.UINT8:
193 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -0700194 elif t == DType.INT16:
195 return 16
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100196 elif t == DType.UINT16:
197 return 16
Eric Kunzee5e26762020-10-13 16:11:07 -0700198 elif t == DType.INT32:
199 return 32
200 elif t == DType.INT48:
201 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +0100202 elif t == DType.FLOAT:
203 return 32
204 elif t == DType.BOOL:
205 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 else:
Les Bell729b0352021-11-24 10:28:21 +0000207 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
    # Operator build functions
    # The argument generators (see generator.tosa_arg_gen) return a list of
    # tuples (stringDescriptor, [build_fcn_arg_list]). The string descriptor
    # is used to generate the test name, and build_fcn_arg_list is expanded
    # and passed to the matching build_* function below.
214
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100215 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
216 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
217
Matthew Haddon848efb42021-09-09 12:30:53 +0100218 # build_placeholder returns an int, ABS/other ops does not
219 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000220 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100221 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000222 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000223 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100224 return result_tens
225
226 # Ensure new output type has correct qinfo
227 if error_name == ErrorIf.WrongOutputType:
228 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000229 qinfo = [
230 TosaQuantGen.getZeroPoint(self, a.dtype),
231 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
232 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100233
234 # Invalidate Input/Output list for error if checks.
235 input_list = [a.name]
236 output_list = [result_tens.name]
237 pCount, cCount = op["operands"]
238 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000239 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
240 self, error_name, input_list, output_list
241 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100242
Les Bell729b0352021-11-24 10:28:21 +0000243 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100244 self.ser,
245 validator_fcns,
246 error_name,
247 op=op,
248 input_dtype=a.dtype,
249 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000250 qinfo=qinfo,
251 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100252 input_list=input_list,
253 output_list=output_list,
254 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000255 ):
256 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100257
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000258 attr = None
259 if op["op"] == Op.NEGATE:
260 attr = ts.TosaSerializerAttribute()
261 attr.NegateAttribute(qinfo[0], qinfo[1])
262
263 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700264 return result_tens
265
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100266 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000267 result_tens = OutputShaper.binaryBroadcastOp(
268 self.ser, self.rng, a, b, error_name
269 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100270
271 # Invalidate Input/Output list for error if checks.
272 input_list = [a.name, b.name]
273 output_list = [result_tens.name]
274 pCount, cCount = op["operands"]
275 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000276 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
277 self, error_name, input_list, output_list
278 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100279
Les Bell729b0352021-11-24 10:28:21 +0000280 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100281 self.ser,
282 validator_fcns,
283 error_name,
284 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000285 input1=a,
286 input2=b,
287 input_dtype=a.dtype,
288 output_dtype=result_tens.dtype,
289 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100290 input_list=input_list,
291 output_list=output_list,
292 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000293 ):
294 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100295
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000296 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700297 return result_tens
298
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100299 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000301 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700302 return result_tens
303
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000304 def build_arithmetic_right_shift(
305 self, op, a, b, round, validator_fcns=None, error_name=None
306 ):
307 result_tens = OutputShaper.binaryBroadcastOp(
308 self.ser, self.rng, a, b, error_name
309 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100310
311 # Invalidate Input/Output list for error if checks.
312 input_list = [a.name, b.name]
313 output_list = [result_tens.name]
314 pCount, cCount = op["operands"]
315 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000316 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
317 self, error_name, input_list, output_list
318 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100319
Les Bell729b0352021-11-24 10:28:21 +0000320 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100321 self.ser,
322 validator_fcns,
323 error_name,
324 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000325 input1=a,
326 input2=b,
327 input_dtype=a.dtype,
328 output_dtype=result_tens.dtype,
329 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100330 input_list=input_list,
331 output_list=output_list,
332 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000333 ):
334 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800335
336 attr = ts.TosaSerializerAttribute()
337 attr.ArithmeticRightShiftAttribute(round)
338
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000339 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800340 return result_tens
341
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100342 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000343 result_tens = OutputShaper.binaryBroadcastOp(
344 self.ser, self.rng, a, b, error_name
345 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700346
347 # Special for multiply:
348 # Force the result to INT32 for INT types
349 if a.dtype != DType.FLOAT:
350 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100351 if error_name == ErrorIf.WrongOutputType:
352 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
353 outputDType = self.rng.choice(all_dtypes)
354 result_tens.setDtype(outputDType)
355
356 # Invalidate Input/Output list for error if checks.
357 input_list = [a.name, b.name]
358 output_list = [result_tens.name]
359 pCount, cCount = op["operands"]
360 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000361 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
362 self, error_name, input_list, output_list
363 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100364
Les Bell729b0352021-11-24 10:28:21 +0000365 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100366 self.ser,
367 validator_fcns,
368 error_name,
369 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000370 input1=a,
371 input2=b,
372 input_dtype=a.dtype,
373 output_dtype=result_tens.dtype,
374 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100375 input_list=input_list,
376 output_list=output_list,
377 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000378 ):
379 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700380
Kevin Chengaee1fac2020-11-11 13:54:06 -0800381 attr = ts.TosaSerializerAttribute()
382 attr.MulAttribute(shift)
383
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700385 return result_tens
386
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100387 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
388 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700389
Kevin Chengfe392ce2021-10-18 21:51:55 +0000390 attr = ts.TosaSerializerAttribute()
391 attr.TableAttribute(table)
392
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100393 # Invalidate Input/Output list for error if checks.
394 input_list = [a.name]
395 output_list = [result_tens.name]
396 pCount, cCount = op["operands"]
397 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000398 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
399 self, error_name, input_list, output_list
400 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100401
Les Bell729b0352021-11-24 10:28:21 +0000402 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100403 self.ser,
404 validator_fcns,
405 error_name,
406 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000407 input_shape=a.shape,
408 input_dtype=a.dtype,
409 output_dtype=result_tens.dtype,
410 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100411 input_list=input_list,
412 output_list=output_list,
413 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000414 ):
415 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100416
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000417 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700418
419 return result_tens
420
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100421 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
422 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
423
424 # Invalidate Input/Output list for error if checks.
425 input_list = [cond.name, a.name, b.name]
426 output_list = [result_tens.name]
427 pCount, cCount = op["operands"]
428 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000429 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
430 self, error_name, input_list, output_list
431 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100432
Les Bell729b0352021-11-24 10:28:21 +0000433 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100434 self.ser,
435 validator_fcns,
436 error_name,
437 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000438 input1=cond,
439 input2=a,
440 input3=b,
441 input_shape=a.shape,
442 input_dtype=a.dtype,
443 output_dtype=result_tens.dtype,
444 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100445 input_list=input_list,
446 output_list=output_list,
447 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000448 ):
449 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100450
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000451 self.ser.addOperator(
452 op["op"],
453 input_list,
454 output_list,
455 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700456 return result_tens
457
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100458 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000459 result_tens = OutputShaper.binaryComparisonOp(
460 self.ser, self.rng, a, b, error_name
461 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100462
463 # Invalidate Input/Output list for error if checks.
464 input_list = [a.name, b.name]
465 output_list = [result_tens.name]
466 pCount, cCount = op["operands"]
467 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000468 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
469 self, error_name, input_list, output_list
470 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100471
Les Bell729b0352021-11-24 10:28:21 +0000472 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100473 self.ser,
474 validator_fcns,
475 error_name,
476 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000477 input1=a,
478 input2=b,
479 input_shape=a.shape,
480 input_dtype=a.dtype,
481 output_shape=result_tens.shape,
482 output_dtype=result_tens.dtype,
483 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100484 input_list=input_list,
485 output_list=output_list,
486 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000487 ):
488 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100489
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000490 self.ser.addOperator(
491 op["op"],
492 input_list,
493 output_list,
494 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700495 return result_tens
496
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100497 def build_argmax(self, op, a, axis, validator_fcns, error_name):
498 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
499
500 # Invalidate Input/Output list for error if checks.
501 input_list = [a.name]
502 output_list = [result_tens.name]
503 pCount, cCount = op["operands"]
504 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000505 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
506 self, error_name, input_list, output_list
507 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100508
Les Bell729b0352021-11-24 10:28:21 +0000509 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100510 self.ser,
511 validator_fcns,
512 error_name,
513 op=op,
514 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000515 input_shape=a.shape,
516 input_dtype=a.dtype,
517 output_shape=result_tens.shape,
518 output_dtype=result_tens.dtype,
519 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100520 input_list=input_list,
521 output_list=output_list,
522 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000523 ):
524 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700525
526 attr = ts.TosaSerializerAttribute()
527 attr.AxisAttribute(axis)
528
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000529 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700530 return result_tens
531
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000532 def build_pool2d(
533 self,
534 op,
535 input,
536 stride,
537 pad,
538 kernel,
539 validator_fcns=None,
540 error_name=None,
541 qinfo=None,
542 ):
543 result_tens = OutputShaper.pool2dOp(
544 self.ser, self.rng, input, kernel, stride, pad, error_name
545 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100546
547 # Ensure new output type has correct qinfo
548 if error_name == ErrorIf.WrongInputType:
549 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000550 qinfo = [
551 TosaQuantGen.getZeroPoint(self, input.dtype),
552 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
553 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100554
555 # Invalidate Input/Output list for error if checks.
556 input_list = [input.name]
557 output_list = [result_tens.name]
558 pCount, cCount = op["operands"]
559 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000560 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
561 self, error_name, input_list, output_list
562 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100563
Les Bell729b0352021-11-24 10:28:21 +0000564 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100565 self.ser,
566 validator_fcns,
567 error_name,
568 op=op,
569 input_shape=input.shape,
570 input_dtype=input.dtype,
571 output_shape=result_tens.shape,
572 output_dtype=result_tens.dtype,
573 kernel=kernel,
574 stride=stride,
575 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000576 qinfo=qinfo,
577 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100578 input_list=input_list,
579 output_list=output_list,
580 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000581 ):
582 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700583
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000584 if qinfo is None:
585 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700586
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000587 attr = ts.TosaSerializerAttribute()
588 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
589
590 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700591 return result_tens
592
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000593 def build_conv2d(
594 self,
595 op,
596 ifm,
597 filter,
598 bias,
599 strides,
600 padding,
601 dilations,
602 validator_fcns=None,
603 error_name=None,
604 qinfo=None,
605 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800606 assert len(padding) == 4
607 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000608 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
609 )
610
611 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000612 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
613 DType.INT8,
614 DType.UINT8,
615 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000616 qinfo = [
617 TosaQuantGen.getZeroPoint(self, ifm.dtype),
618 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
619 ]
Les Bell0e027d42021-11-09 14:42:14 +0000620
621 # Invalidate Input/Output list for error_if checks.
622 input_list = [ifm.name, filter.name, bias.name]
623 output_list = [result_tens.name]
624 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000625 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
626 self, error_name, input_list, output_list
627 )
Les Bell0e027d42021-11-09 14:42:14 +0000628
Les Bell729b0352021-11-24 10:28:21 +0000629 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000630 self.ser,
631 validator_fcns,
632 error_name,
633 op=op,
634 input_dtype=ifm.dtype,
635 weight_dtype=filter.dtype,
636 output_dtype=result_tens.dtype,
637 qinfo=qinfo,
638 input_list=input_list,
639 num_operands=num_operands,
640 output_list=output_list,
641 pad=padding,
642 stride=strides,
643 dilation=dilations,
644 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100645 weight_shape=filter.shape,
646 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000647 ):
648 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700649
650 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000651 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700652
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000653 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700654 return result_tens
655
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000656 def build_conv3d(
657 self,
658 op,
659 ifm,
660 filter,
661 bias,
662 strides,
663 padding,
664 dilations,
665 validator_fcns=None,
666 error_name=None,
667 qinfo=None,
668 ):
Kevin Cheng1533b852021-09-01 12:51:58 -0700669 assert len(padding) == 6
670 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000671 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
672 )
673
674 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000675 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
676 DType.INT8,
677 DType.UINT8,
678 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000679 qinfo = [
680 TosaQuantGen.getZeroPoint(self, ifm.dtype),
681 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
682 ]
Les Bell0e027d42021-11-09 14:42:14 +0000683
684 # Invalidate Input/Output list for error_if checks.
685 input_list = [ifm.name, filter.name, bias.name]
686 output_list = [result_tens.name]
687 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000688 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
689 self, error_name, input_list, output_list
690 )
Les Bell0e027d42021-11-09 14:42:14 +0000691
Les Bell729b0352021-11-24 10:28:21 +0000692 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000693 self.ser,
694 validator_fcns,
695 error_name,
696 op=op,
697 input_dtype=ifm.dtype,
698 weight_dtype=filter.dtype,
699 output_dtype=result_tens.dtype,
700 qinfo=qinfo,
701 input_list=input_list,
702 num_operands=num_operands,
703 output_list=output_list,
704 pad=padding,
705 stride=strides,
706 dilation=dilations,
707 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100708 weight_shape=filter.shape,
709 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000710 ):
711 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700712
713 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000714 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Kevin Cheng1533b852021-09-01 12:51:58 -0700715
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000716 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700717 return result_tens
718
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator.

    ``out_pad`` must contain exactly four entries.  ``qinfo`` is indexed as a
    two-element ``[input zero point, output zero point]`` sequence when the
    attribute is built (the same layout it is rebuilt with below for
    WrongInputType tests).

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, error_name
    )

    # WrongInputType tests with a non-8-bit input get zero points derived
    # from the actual input/output dtypes so qinfo stays consistent.
    needs_new_qinfo = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if needs_new_qinfo:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate the input/output name lists for ERROR_IF checks.
    operand_count = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
780
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator.

    ``qinfo`` is indexed as ``[input zero point, output zero point]`` when the
    ConvAttribute is built.  Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # WrongInputType tests with a non-8-bit input get zero points derived
    # from the actual input/output dtypes so qinfo stays consistent.
    needs_new_qinfo = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if needs_new_qinfo:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate the input/output name lists for ERROR_IF checks.
    operand_count = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
842
def build_fully_connected(
    self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a FULLY_CONNECTED operator.

    ``qinfo[0]`` and ``qinfo[1]`` are forwarded to FullyConnectedAttribute.
    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
882
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Serialize a MATMUL operator with inputs ``a`` and ``b``.

    ``qinfo[0]`` and ``qinfo[1]`` are forwarded to MatMulAttribute.
    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
919
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REDUCE_* operator over ``axis`` of tensor ``a``.

    ``validator_fcns`` now defaults to None, matching the other builder
    methods; existing callers that pass it explicitly are unaffected.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
954
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a CLAMP operator with random bounds drawn for ``a.dtype``.

    For the MaxSmallerMin error test the two drawn values are forced to
    differ and then assigned so that ``max_val < min_val``.  Returns the
    result tensor, or None when ``TosaErrorValidator.evValidateErrorIfs``
    returns a falsy value.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    pair = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Re-draw until the values differ so the swapped bounds are invalid.
        while pair[0] == pair[1]:
            pair = draw_pair()
        min_val, max_val = max(pair), min(pair)
    else:
        min_val, max_val = min(pair), max(pair)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    # FLOAT inputs pass the bounds as the last two ClampAttribute arguments;
    # every other dtype passes them as the first two.
    if a.dtype == DType.FLOAT:
        clamp_args = (0, 0, min_val, max_val)
    else:
        clamp_args = (min_val, max_val, 0, 0)
    attr = ts.TosaSerializerAttribute()
    attr.ClampAttribute(*clamp_args)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return result_tens
1005
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator whose attribute is a random FLOAT value.

    No ERROR_IF validation is run here; ``validator_fcns`` is not used.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leak_value = self.getRandNumberDType(DType.FLOAT)

    attribute = ts.TosaSerializerAttribute()
    attribute.LeakyReluAttribute(leak_value)

    self.ser.addOperator(op["op"], [a.name], [out.name], attribute)
    return out
1014
# TODO: needs an additional type/input (only `a` is serialized below).
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU operator with a single input and no attribute.

    No ERROR_IF validation is run here; ``validator_fcns`` is not used.
    """
    output_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    inputs, outputs = [a.name], [output_tensor.name]
    self.ser.addOperator(op["op"], inputs, outputs)
    return output_tensor
1021
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a SIGMOID operator (no attribute).

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
1052
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a TANH operator (no attribute).

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
1083
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a CONCAT operator over a variable number of input tensors.

    The last positional argument in ``a`` is the concatenation axis; the
    preceding ones are the input tensors.  Returns the result tensor, or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    if error_name != ErrorIf.WrongInputType:
        # Outside WrongInputType tests the trailing argument must be the axis.
        assert isinstance(a[-1], int)

    # To store a variable-length list of input tensors, the axis is stored
    # along with it as the final element.
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001132
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator.

    ``padding`` is flattened (via ``.flatten()``) into the PadAttribute along
    with the integer and float pad constants.  Returns the result tensor, or
    None when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, pad_attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001178
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE operator targeting ``newShape``.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return result_tens
1214
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REVERSE operator along ``axis``.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return result_tens
1249
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE operator using permutation ``perms``.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, perm_attr)
    return result_tens
1284
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE operator described by ``start`` and ``size``.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        start=start,
        size=size,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return result_tens
1322
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE operator repeating ``a`` by ``multiples``.

    Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], inputs, outputs, tile_attr)
    return result_tens
1356
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER operator with a freshly generated constant index tensor.

    Index values are drawn with ``self.rng.integers`` from ``[0,
    values.shape[1])`` and have shape ``(values.shape[0], W)``, where W comes
    from ``self.randInt`` over ``self.args.tensor_shape_range``.  Returns the
    result tensor, or None when ``TosaErrorValidator.evValidateErrorIfs``
    returns a falsy value.
    """
    k_dim = values.shape[1]
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=values.shape,
        output_shape=result_tens.shape,
        input_dtype=values.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
1403
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER operator with a freshly generated constant index tensor.

    Index values are drawn with ``self.rng.integers`` from ``[0,
    values_in.shape[1])`` and have shape ``(values_in.shape[0],
    input.shape[1])``.  Returns the result tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    # NOTE(review): indices are sampled with replacement, so the same index
    # can repeat within a batch row; confirm this is acceptable for SCATTER.
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # Invalidate the input/output name lists for ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [result_tens.name],
    )

    error_if_args = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return result_tens
1448
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator.

    The expected result tensor comes from OutputShaper.resizeOp using the
    mode, scale, offset and border arguments. The same values are recorded
    in the operator's ResizeAttribute.

    Returns the result tensor, or None when ERROR_IF validation rejects the
    generated test.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the tensor name lists may be deliberately invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=result_tens.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=pCount + cCount,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
1510
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator forwarding two input tensors.

    One output tensor is created per input via OutputShaper.unaryOp, and a
    single operator maps [val, val2] to both outputs. Only the first result
    tensor is returned. validator_fcns is accepted for signature consistency
    with the other build_* methods but no ERROR_IF validation is done here.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # `op` is the operator's entry (a dict, as in every other build_* method);
    # the serializer needs the opcode stored under "op", not the whole entry.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1518
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark the supplied tensor as an output of the current block.

    op, validator_fcns and error_name are unused; no ERROR_IF validation is
    performed. The same tensor object that was passed in is returned.
    """
    const_tens = val
    self.ser.addOutputTensor(const_tens)
    return const_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001522
1523 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a type-conversion (cast) operator from val to out_dtype.

    The result tensor is produced by OutputShaper.typeConversionOp.

    Returns the result tensor, or None when ERROR_IF validation rejects the
    generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the tensor name lists may be deliberately invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    validation_args = dict(
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1556
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator with randomly chosen zero points and scales.

    Zero points are drawn according to the input/output dtypes, or forced
    non-zero for the zero-point ERROR_IF tests. A random scale is turned into
    multiplier/shift pairs. For positive (error_name is None) scale32 tests,
    the input placeholder data stored on disk is clamped so that
    (value - input_zp) lies inside the apply_scale_32 range for its shift.

    Args:
        op: operator entry; its "op" and "operands" keys are read.
        val: input tensor.
        out_dtype: output DType.
        scale32: True for 32-bit multipliers, False for 16-bit ones.
        double_round: forwarded to the validators and the RescaleAttribute.
        per_channel: if True, one multiplier/shift pair per element of the
            last dimension of val; otherwise a single pair.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf under test, or None for a positive test.

    Returns:
        The result tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Each branch that uses a zero point widens the type by one bit.
    # NOTE(review): presumably to cover the range after zero-point removal.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # Force a non-zero input zero point for the ERROR_IF test
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # Force a non-zero output zero point for the ERROR_IF test
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr holds np.int32. Under NumPy 2 promotion rules (NEP 50),
        # `1 << np.int32(s)` is evaluated in int32 and overflows for shifts
        # above 32, so do the bound arithmetic with a Python int instead.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1711
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each output an INT32 constant.

    The supplied then_tens/else_tens are ignored apart from the shape of
    then_tens. Fresh random INT32 constants fill the two blocks. For the
    CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the matching
    block outputs a constant whose shape has been perturbed.

    Returns the result tensor, or None when ERROR_IF validation rejects the
    generated test.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])
    out_shape = then_tens.shape

    mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in mismatch_errors:
        # Change every dimension; dimensions of 3 or less only grow
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result always uses the unperturbed shape
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    for block_name, block_error, block_arr in (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    ):
        self.ser.startBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    return result_tens if passed else None
1779
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to a and b.

    FLOAT/INT32 inputs use ADD (then) and SUB (else). INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (then) and LOGICAL_LEFT_SHIFT (else). For the
    CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests, the
    relevant block gets a first input or an output with a perturbed shape.

    Returns the result tensor, or None when ERROR_IF validation rejects the
    generated test.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i, dim in enumerate(incorrect_shape):
            # Only grow small dimensions so the perturbed shape stays
            # positive (same scheme as build_cond_if_const)
            if dim > 3:
                incorrect_shape[i] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[i] = dim + self.rng.choice([1, 2, 4])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be named `op`: rebinding it would hand the
    # ERROR_IF validation below a bare Op value instead of the operator entry.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1861
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP carrying (iteration counter, a, accumulator).

    The accumulator starts at zero. The COND block outputs
    GREATER(counter, 0). The BODY block computes ADD(a, acc) and
    SUB(counter, 1), passing a through unchanged. Presumably this adds a into
    the accumulator iter_val times. For the list-mismatch ERROR_IF tests, the
    relevant tensors get perturbed shapes, and for
    CondGraphOutputNotMatchingBool the COND output gets a non-BOOL dtype.

    Returns the loop's accumulator output tensor, or None when ERROR_IF
    validation rejects the generated test.
    """

    def perturb_dims(shape):
        # Offset every dimension of shape, in place, by a random non-zero amount
        for idx in range(len(shape)):
            shape[idx] += self.rng.choice([-3, -2, 2, 3])

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator placeholder, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        bad_out_acc = deepcopy(acc)
        perturb_dims(bad_out_acc.shape)
        acc_out = self.ser.addIntermediate(bad_out_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        bad_iter = deepcopy(iter_tens)
        perturb_dims(bad_iter.shape)
        if len(bad_iter.shape) == 0:
            # Scalar counter: add a dimension so the shapes actually differ
            bad_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        bad_acc = deepcopy(acc)
        perturb_dims(bad_acc.shape)

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_iter, a, bad_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_tens = self.ser.addOutput(
            [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
        )
    else:
        cond_tens = self.ser.addOutput([], DType.BOOL)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc).
    # Local intermediate tensors must be declared here for the outputs.
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_iter, a, bad_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(bad_iter.shape, bad_iter.dtype)
        acc_body_out = self.ser.addIntermediate(bad_acc.shape, bad_acc.dtype)
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    return acc_out if passed else None
1973
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Compute the shape/rank/dtype filters for generating one op's tests.

    For "positive" tests, the requested filters are intersected with what the
    operator allows (op["rank"] bounds and op["types"]). With neither ranks
    nor shapes requested, ranks default to the smaller of 1-4 and the
    operator's range. For "negative" tests, any rank/dtype/shape requirement
    reported by validator(check=False, op=op)["param_reqs"] overrides the
    positive filters, and only the first two shapes are kept otherwise.

    Returns a dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys,
    or None for a negative test without a validator or an unknown testType.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small
    default_test_rank_range = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    rmin, rmax = op["rank"]
    op_rank_range = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Only keep requested ranks the operator supports
        ranks = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapes[0] is None:
        # Bound the default by the default range or the operator's range,
        # whichever covers fewer ranks
        if len(op_rank_range) <= len(default_test_rank_range):
            ranks = op_rank_range
        else:
            ranks = default_test_rank_range
    else:
        ranks = op_rank_range

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # Operator dtypes filtered by the requested ones; list entries match
        # on their first element
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    req_rank = reqs["rank"]
    req_dtype = reqs["dtype"]
    req_shape = reqs["shape"]
    return {
        # Reduce number of shapes to keep test numbers small
        "shapeFilter": req_shape if req_shape is not None else shapes[:2],
        "rankFilter": req_rank if req_rank is not None else ranks,
        "dtypeFilter": req_dtype if req_dtype is not None else dtypes,
    }
2054
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptors for one operator.

    Args:
        opName: Key into ``TOSA_OP_LIST`` naming the operator to test.
        shapeFilter: Optional list of target shapes. ``None`` (the default)
            is treated as ``[None]``, i.e. no specific target shape.
        rankFilter: Optional ranks to restrict generation to.
        dtypeFilter: Optional dtypes to restrict generation to.
        testType: ``"positive"`` for valid tests, or ``"negative"`` for
            ERROR_IF tests (one pass per entry of the op's
            ``error_if_validators``).

    Returns:
        A list of tuples ``(opName, testStr, dtype, error_name, shapeList, args)``.
        For positive tests, any test rejected by at least one of the op's
        ``invalid_test_validators`` is removed.

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
    """
    if shapeFilter is None:
        # Same meaning as the previous ``[None]`` default, without a shared
        # mutable default argument.
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator (re-seeded on every call)
    self.rng = np.random.default_rng(self.random_seed)

    _build_fcn, tgen_fcn, _tvgen_fcn, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each argument entry is a tuple of (string representation,
                    # build function argument list)
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped if ANY validator flags it
        # (previously the flag was reset per validator, so only the last
        # validator's verdict counted).
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2161
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single generated test into a graph and serialize it.

    Args:
        opName: Key into ``TOSA_OP_LIST`` naming the operator.
        testStr: Test name, passed to ``createSerializer``.
        dtype_or_dtypeList: Either an explicit per-operand list of dtypes, or
            a single dtype that is replicated for every operand (for CONCAT,
            once per entry of ``shapeList``).
        error_name: ERROR_IF condition under test, or ``None`` for a
            positive test.
        shapeList: Shapes of the operand tensors.
        testArgs: Positional arguments forwarded to the op's build function.

    Raises:
        Exception: if ``opName`` is not in ``TOSA_OP_LIST``.
        TypeError: re-raised from the build function after printing the call
            details, to help debug argument mismatches.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tgen_fcn, tvgen_fcn, _agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT takes a variable number of inputs: one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
    else:
        qinfo = None

    # Build the random tensor operands
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    # Optional build arguments are always passed by keyword. Previously, when
    # there were no error_if_validators, qinfo was passed positionally after
    # *testArgs and could bind to whichever optional parameter came next. The
    # ERROR_IF path already passed qinfo=qinfo.
    build_kwargs = {}
    if error_if_validators is not None:
        build_kwargs["validator_fcns"] = error_if_validators
        build_kwargs["error_name"] = error_name
    if qinfo is not None:
        build_kwargs["qinfo"] = qinfo

    try:
        resultName = build_fcn(self, op, *tens, *testArgs, **build_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002252
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` entries of TOSA_OP_LIST into per-kernel ops.

    For every 2D kernel size, this adds ``conv2d_HxW``,
    ``depthwise_conv2d_HxW`` and ``transpose_conv2d_HxW`` entries. For every
    3D kernel size, it adds a ``conv3d_DxHxW`` entry. Each new entry is a
    shallow copy of its template, with ``filter`` set to the kernel and
    ``template`` set to False. Afterwards, every entry whose ``template``
    flag is truthy is deleted.

    TOSA_OP_LIST is modified in place. If ``conv2d_TEMPLATE`` is absent, the
    expansion has already happened and nothing is done.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    def _add_variant(template_name, test_name, kernel):
        # Copy the template and specialise it for one kernel size
        variant = op_list[template_name].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list[test_name] = variant

    kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    for kernel in kernels_2d:
        suffix = "{}x{}".format(kernel[0], kernel[1])
        # Kernel-major order keeps the same insertion order as before
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            _add_variant(prefix + "_TEMPLATE", prefix + "_" + suffix, kernel)

    kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    for kernel in kernels_3d:
        name = "conv3d_{}x{}x{}".format(kernel[0], kernel[1], kernel[2])
        _add_variant("conv3d_TEMPLATE", name, kernel)

    # Collect the template keys first, then delete them: a dict must not
    # change size while it is being iterated.
    template_keys = [name for name, entry in op_list.items() if entry.get("template")]
    for name in template_keys:
        del op_list[name]
2303
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Required fields:
        - ``operands``, which must unpack into two values;
        - ``build_fcn``, which must unpack into four values;
        - ``types``;
        - ``op``.

    A missing ``rank`` is set to ``DEFAULT_RANK_RANGE``.

    Raises:
        Exception: naming the first op (in iteration order) whose required
            fields are missing or malformed.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder count, const count)
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        # Required: (build, tensor gen, tensor values gen, arg gen)
        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(name)
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(name)
            )

        if "op" not in entry:
            raise Exception("Op {} is missing the Op field in TOSA_OP_LIST".format(name))

        # Optional: default rank range when the op does not restrict it
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002345
# Tensor operator list (TOSA_OP_LIST), keyed by test name. Each entry has:
# 'op': the Op enum value of the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#         if not specified, initOpListDefaults() sets it to DEFAULT_RANK_RANGE,
#         i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': 4-tuple of (build_operator(), TensorGen function,
#              TensorValuesGen function, ArgGen function or None)
# 'types': array of datatypes to be tested
# Optional keys: 'qgen' (qinfo generator), 'invalid_test_validators',
# 'error_if_validators', and 'template': True for entries that
# createDynamicOpLists() expands per kernel size.
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Each list entry gives one dtype per operand; a bare dtype is replicated
# for every operand when the test is serialized.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002373
2374 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002375 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002376 "argmax": {
2377 "op": Op.ARGMAX,
2378 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002379 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002380 "build_fcn": (
2381 build_argmax,
2382 TosaTensorGen.tgBasic,
2383 TosaTensorValuesGen.tvgDefault,
2384 TosaArgGen.agAxis,
2385 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002386 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002387 "error_if_validators": (
2388 TosaErrorValidator.evAxisSmallerZero,
2389 TosaErrorValidator.evAxisLargerRank,
2390 TosaErrorValidator.evArgmaxOutputRankMismatch,
2391 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2392 TosaErrorValidator.evWrongRank,
2393 TosaErrorValidator.evWrongInputType,
2394 TosaErrorValidator.evWrongOutputType,
2395 TosaErrorValidator.evWrongInputList,
2396 TosaErrorValidator.evWrongOutputList,
2397 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002398 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002399 "avg_pool2d": {
2400 "op": Op.AVG_POOL2D,
2401 "operands": (1, 0),
2402 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002403 "build_fcn": (
2404 build_pool2d,
2405 TosaTensorGen.tgNHWC,
2406 TosaTensorValuesGen.tvgDefault,
2407 TosaArgGen.agPooling,
2408 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002409 "qgen": TosaQuantGen.qgUnary,
2410 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002411 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002412 "error_if_validators": (
2413 TosaErrorValidator.evKernelSmallerOne,
2414 TosaErrorValidator.evStrideSmallerOne,
2415 TosaErrorValidator.evPadSmallerZero,
2416 TosaErrorValidator.evWrongRank,
2417 TosaErrorValidator.evWrongInputType,
2418 TosaErrorValidator.evWrongOutputType,
2419 TosaErrorValidator.evWrongInputList,
2420 TosaErrorValidator.evWrongOutputList,
2421 TosaErrorValidator.evInputZeroPointNotZero,
2422 TosaErrorValidator.evOutputZeroPointNotZero,
2423 TosaErrorValidator.evPadLargerEqualKernel,
2424 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002425 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002426 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002427 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002428 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002429 "conv2d_TEMPLATE": {
2430 "op": Op.CONV2D,
2431 "operands": (1, 2),
2432 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002433 "build_fcn": (
2434 build_conv2d,
2435 TosaTensorGen.tgConv2D,
2436 TosaTensorValuesGen.tvgDefault,
2437 TosaArgGen.agConv,
2438 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002439 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002440 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002441 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2442 "error_if_validators": (
2443 TosaErrorValidator.evWrongInputType,
2444 TosaErrorValidator.evWrongOutputType,
2445 TosaErrorValidator.evWrongInputList,
2446 TosaErrorValidator.evWrongOutputList,
2447 TosaErrorValidator.evInputZeroPointNotZero,
2448 TosaErrorValidator.evWeightZeroPointNotZero,
2449 TosaErrorValidator.evPadSmallerZero,
2450 TosaErrorValidator.evStrideSmallerOne,
2451 TosaErrorValidator.evDilationSmallerOne,
2452 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002453 TosaErrorValidator.evConvOutputShapeMismatch,
2454 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002455 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002456 "template": True,
2457 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002458 # Templated operator. Filled in by createDynamicOpLists
2459 "conv3d_TEMPLATE": {
2460 "op": Op.CONV3D,
2461 "operands": (1, 2),
2462 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002463 "build_fcn": (
2464 build_conv3d,
2465 TosaTensorGen.tgConv3D,
2466 TosaTensorValuesGen.tvgDefault,
2467 TosaArgGen.agConv,
2468 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002469 "qgen": TosaQuantGen.qgConv,
2470 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002471 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2472 "error_if_validators": (
2473 TosaErrorValidator.evWrongInputType,
2474 TosaErrorValidator.evWrongOutputType,
2475 TosaErrorValidator.evWrongInputList,
2476 TosaErrorValidator.evWrongOutputList,
2477 TosaErrorValidator.evInputZeroPointNotZero,
2478 TosaErrorValidator.evWeightZeroPointNotZero,
2479 TosaErrorValidator.evPadSmallerZero,
2480 TosaErrorValidator.evStrideSmallerOne,
2481 TosaErrorValidator.evDilationSmallerOne,
2482 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002483 TosaErrorValidator.evConvOutputShapeMismatch,
2484 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002485 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002486 "template": True,
2487 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002488 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002489 "depthwise_conv2d_TEMPLATE": {
2490 "op": Op.DEPTHWISE_CONV2D,
2491 "operands": (1, 2),
2492 "filter": [1, 1],
2493 "rank": (4, 4),
2494 "build_fcn": (
2495 build_depthwise_conv2d,
2496 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002497 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002498 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002499 ),
2500 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002501 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002502 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2503 "error_if_validators": (
2504 TosaErrorValidator.evWrongInputType,
2505 TosaErrorValidator.evWrongOutputType,
2506 TosaErrorValidator.evWrongInputList,
2507 TosaErrorValidator.evWrongOutputList,
2508 TosaErrorValidator.evInputZeroPointNotZero,
2509 TosaErrorValidator.evWeightZeroPointNotZero,
2510 TosaErrorValidator.evPadSmallerZero,
2511 TosaErrorValidator.evStrideSmallerOne,
2512 TosaErrorValidator.evDilationSmallerOne,
2513 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002514 TosaErrorValidator.evConvOutputShapeMismatch,
2515 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002516 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002517 "template": True,
2518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002519 "fully_connected": {
2520 "op": Op.FULLY_CONNECTED,
2521 "operands": (1, 2),
2522 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002523 "build_fcn": (
2524 build_fully_connected,
2525 TosaTensorGen.tgFullyConnected,
2526 TosaTensorValuesGen.tvgDefault,
2527 None,
2528 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002529 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002530 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002531 "error_if_validators": (
2532 TosaErrorValidator.evInputZeroPointNotZero,
2533 TosaErrorValidator.evWeightZeroPointNotZero,
2534 TosaErrorValidator.evWrongRank,
2535 TosaErrorValidator.evWrongInputType,
2536 TosaErrorValidator.evWrongOutputType,
2537 TosaErrorValidator.evWrongInputList,
2538 TosaErrorValidator.evWrongOutputList,
2539 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002540 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002541 "matmul": {
2542 "op": Op.MATMUL,
2543 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002544 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002545 "build_fcn": (
2546 build_matmul,
2547 TosaTensorGen.tgMatmul,
2548 TosaTensorValuesGen.tvgDefault,
2549 None,
2550 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002551 "qgen": TosaQuantGen.qgMatmul,
2552 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002553 "error_if_validators": (
2554 TosaErrorValidator.evInputZeroPointNotZero,
2555 TosaErrorValidator.evWrongRank,
2556 TosaErrorValidator.evWrongInputType,
2557 TosaErrorValidator.evWrongOutputType,
2558 TosaErrorValidator.evWrongInputList,
2559 TosaErrorValidator.evWrongOutputList,
2560 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002561 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002562 "max_pool2d": {
2563 "op": Op.MAX_POOL2D,
2564 "operands": (1, 0),
2565 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002566 "build_fcn": (
2567 build_pool2d,
2568 TosaTensorGen.tgNHWC,
2569 TosaTensorValuesGen.tvgDefault,
2570 TosaArgGen.agPooling,
2571 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002572 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002573 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002574 "error_if_validators": (
2575 TosaErrorValidator.evKernelSmallerOne,
2576 TosaErrorValidator.evStrideSmallerOne,
2577 TosaErrorValidator.evPadSmallerZero,
2578 TosaErrorValidator.evWrongRank,
2579 TosaErrorValidator.evWrongInputType,
2580 TosaErrorValidator.evWrongOutputType,
2581 TosaErrorValidator.evWrongInputList,
2582 TosaErrorValidator.evWrongOutputList,
2583 TosaErrorValidator.evPadLargerEqualKernel,
2584 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002585 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002586 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002587 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002588 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002589 "transpose_conv2d_TEMPLATE": {
2590 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002591 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002592 "rank": (4, 4),
2593 "build_fcn": (
2594 build_transpose_conv2d,
2595 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002596 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002597 TosaArgGen.agTransposeConv2D,
2598 ),
2599 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002600 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002601 "invalid_test_validators": (
2602 TosaInvalidValidator.ivHeightWidthInvalid,
2603 TosaInvalidValidator.ivNonPositiveOutputShape,
2604 ),
2605 "error_if_validators": (
2606 TosaErrorValidator.evWrongInputType,
2607 TosaErrorValidator.evWrongOutputType,
2608 TosaErrorValidator.evWrongInputList,
2609 TosaErrorValidator.evWrongOutputList,
2610 TosaErrorValidator.evInputZeroPointNotZero,
2611 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002612 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002613 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002614 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002615 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002616 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002617 "template": True,
2618 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002619 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002620 "clamp": {
2621 "op": Op.CLAMP,
2622 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002623 "build_fcn": (
2624 build_clamp,
2625 TosaTensorGen.tgBasic,
2626 TosaTensorValuesGen.tvgDefault,
2627 None,
2628 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002629 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002630 "error_if_validators": (
2631 TosaErrorValidator.evMaxSmallerMin,
2632 TosaErrorValidator.evWrongInputType,
2633 TosaErrorValidator.evWrongOutputType,
2634 TosaErrorValidator.evWrongInputList,
2635 TosaErrorValidator.evWrongOutputList,
2636 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002637 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002638 "sigmoid": {
2639 "op": Op.SIGMOID,
2640 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002641 "build_fcn": (
2642 build_sigmoid,
2643 TosaTensorGen.tgBasic,
2644 TosaTensorValuesGen.tvgDefault,
2645 None,
2646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002647 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002648 "error_if_validators": (
2649 TosaErrorValidator.evWrongInputType,
2650 TosaErrorValidator.evWrongOutputType,
2651 TosaErrorValidator.evWrongInputList,
2652 TosaErrorValidator.evWrongOutputList,
2653 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002654 },
2655 "tanh": {
2656 "op": Op.TANH,
2657 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002658 "build_fcn": (
2659 build_tanh,
2660 TosaTensorGen.tgBasic,
2661 TosaTensorValuesGen.tvgDefault,
2662 None,
2663 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002664 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002665 "error_if_validators": (
2666 TosaErrorValidator.evWrongInputType,
2667 TosaErrorValidator.evWrongOutputType,
2668 TosaErrorValidator.evWrongInputList,
2669 TosaErrorValidator.evWrongOutputList,
2670 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002671 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002672 # Elementwise Binary Operators
2673 "add": {
2674 "op": Op.ADD,
2675 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002676 "build_fcn": (
2677 build_binary_broadcast,
2678 TosaTensorGen.tgBroadcastFuzz,
2679 TosaTensorValuesGen.tvgAddSub,
2680 None,
2681 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002682 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002683 "error_if_validators": (
2684 TosaErrorValidator.evRankMismatch,
2685 TosaErrorValidator.evWrongInputType,
2686 TosaErrorValidator.evWrongOutputType,
2687 TosaErrorValidator.evWrongInputList,
2688 TosaErrorValidator.evWrongOutputList,
2689 TosaErrorValidator.evDimensionMismatch,
2690 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002691 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002692 "arithmetic_right_shift": {
2693 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2694 "operands": (2, 0),
2695 "build_fcn": (
2696 build_arithmetic_right_shift,
2697 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002698 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002699 TosaArgGen.agArithmeticRightShift,
2700 ),
2701 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002702 "error_if_validators": (
2703 TosaErrorValidator.evRankMismatch,
2704 TosaErrorValidator.evWrongInputType,
2705 TosaErrorValidator.evWrongOutputType,
2706 TosaErrorValidator.evWrongInputList,
2707 TosaErrorValidator.evWrongOutputList,
2708 TosaErrorValidator.evDimensionMismatch,
2709 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002710 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002711 "bitwise_and": {
2712 "op": Op.BITWISE_AND,
2713 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002714 "build_fcn": (
2715 build_binary_broadcast,
2716 TosaTensorGen.tgBroadcastFuzz,
2717 TosaTensorValuesGen.tvgDefault,
2718 None,
2719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002720 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002721 "error_if_validators": (
2722 TosaErrorValidator.evRankMismatch,
2723 TosaErrorValidator.evWrongInputType,
2724 TosaErrorValidator.evWrongOutputType,
2725 TosaErrorValidator.evWrongInputList,
2726 TosaErrorValidator.evWrongOutputList,
2727 TosaErrorValidator.evDimensionMismatch,
2728 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002729 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002730 "bitwise_or": {
2731 "op": Op.BITWISE_OR,
2732 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002733 "build_fcn": (
2734 build_binary_broadcast,
2735 TosaTensorGen.tgBroadcastFuzz,
2736 TosaTensorValuesGen.tvgDefault,
2737 None,
2738 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002739 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002740 "error_if_validators": (
2741 TosaErrorValidator.evRankMismatch,
2742 TosaErrorValidator.evWrongInputType,
2743 TosaErrorValidator.evWrongOutputType,
2744 TosaErrorValidator.evWrongInputList,
2745 TosaErrorValidator.evWrongOutputList,
2746 TosaErrorValidator.evDimensionMismatch,
2747 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002748 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002749 "bitwise_xor": {
2750 "op": Op.BITWISE_XOR,
2751 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002752 "build_fcn": (
2753 build_binary_broadcast,
2754 TosaTensorGen.tgBroadcastFuzz,
2755 TosaTensorValuesGen.tvgDefault,
2756 None,
2757 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002758 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002759 "error_if_validators": (
2760 TosaErrorValidator.evRankMismatch,
2761 TosaErrorValidator.evWrongInputType,
2762 TosaErrorValidator.evWrongOutputType,
2763 TosaErrorValidator.evWrongInputList,
2764 TosaErrorValidator.evWrongOutputList,
2765 TosaErrorValidator.evDimensionMismatch,
2766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002767 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002768 "intdiv": {
2769 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002770 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002771 "build_fcn": (
2772 build_binary_broadcast,
2773 TosaTensorGen.tgBroadcastFuzz,
2774 TosaTensorValuesGen.tvgIntDiv,
2775 None,
2776 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002777 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002778 "error_if_validators": (
2779 TosaErrorValidator.evRankMismatch,
2780 TosaErrorValidator.evWrongInputType,
2781 TosaErrorValidator.evWrongOutputType,
2782 TosaErrorValidator.evWrongInputList,
2783 TosaErrorValidator.evWrongOutputList,
2784 TosaErrorValidator.evDimensionMismatch,
2785 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002786 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002787 "logical_and": {
2788 "op": Op.LOGICAL_AND,
2789 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002790 "build_fcn": (
2791 build_binary_broadcast,
2792 TosaTensorGen.tgBroadcastFuzz,
2793 TosaTensorValuesGen.tvgDefault,
2794 None,
2795 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002796 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002797 "error_if_validators": (
2798 TosaErrorValidator.evRankMismatch,
2799 TosaErrorValidator.evWrongInputType,
2800 TosaErrorValidator.evWrongOutputType,
2801 TosaErrorValidator.evWrongInputList,
2802 TosaErrorValidator.evWrongOutputList,
2803 TosaErrorValidator.evDimensionMismatch,
2804 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002805 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002806 "logical_left_shift": {
2807 "op": Op.LOGICAL_LEFT_SHIFT,
2808 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002809 "build_fcn": (
2810 build_binary_broadcast,
2811 TosaTensorGen.tgBroadcastFuzz,
2812 TosaTensorValuesGen.tvgLogicalShift,
2813 None,
2814 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002815 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002816 "error_if_validators": (
2817 TosaErrorValidator.evRankMismatch,
2818 TosaErrorValidator.evWrongInputType,
2819 TosaErrorValidator.evWrongOutputType,
2820 TosaErrorValidator.evWrongInputList,
2821 TosaErrorValidator.evWrongOutputList,
2822 TosaErrorValidator.evDimensionMismatch,
2823 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002824 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002825 "logical_right_shift": {
2826 "op": Op.LOGICAL_RIGHT_SHIFT,
2827 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002828 "build_fcn": (
2829 build_binary_broadcast,
2830 TosaTensorGen.tgBroadcastFuzz,
2831 TosaTensorValuesGen.tvgLogicalShift,
2832 None,
2833 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002834 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002835 "error_if_validators": (
2836 TosaErrorValidator.evRankMismatch,
2837 TosaErrorValidator.evWrongInputType,
2838 TosaErrorValidator.evWrongOutputType,
2839 TosaErrorValidator.evWrongInputList,
2840 TosaErrorValidator.evWrongOutputList,
2841 TosaErrorValidator.evDimensionMismatch,
2842 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002844 "logical_or": {
2845 "op": Op.LOGICAL_OR,
2846 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002847 "build_fcn": (
2848 build_binary_broadcast,
2849 TosaTensorGen.tgBroadcastFuzz,
2850 TosaTensorValuesGen.tvgDefault,
2851 None,
2852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002853 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002854 "error_if_validators": (
2855 TosaErrorValidator.evRankMismatch,
2856 TosaErrorValidator.evWrongInputType,
2857 TosaErrorValidator.evWrongOutputType,
2858 TosaErrorValidator.evWrongInputList,
2859 TosaErrorValidator.evWrongOutputList,
2860 TosaErrorValidator.evDimensionMismatch,
2861 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002862 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002863 "logical_xor": {
2864 "op": Op.LOGICAL_XOR,
2865 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002866 "build_fcn": (
2867 build_binary_broadcast,
2868 TosaTensorGen.tgBroadcastFuzz,
2869 TosaTensorValuesGen.tvgDefault,
2870 None,
2871 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002872 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002873 "error_if_validators": (
2874 TosaErrorValidator.evRankMismatch,
2875 TosaErrorValidator.evWrongInputType,
2876 TosaErrorValidator.evWrongOutputType,
2877 TosaErrorValidator.evWrongInputList,
2878 TosaErrorValidator.evWrongOutputList,
2879 TosaErrorValidator.evDimensionMismatch,
2880 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002881 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002882 "maximum": {
2883 "op": Op.MAXIMUM,
2884 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002885 "build_fcn": (
2886 build_binary_broadcast,
2887 TosaTensorGen.tgBroadcastFuzz,
2888 TosaTensorValuesGen.tvgDefault,
2889 None,
2890 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002891 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002892 "error_if_validators": (
2893 TosaErrorValidator.evRankMismatch,
2894 TosaErrorValidator.evWrongInputType,
2895 TosaErrorValidator.evWrongOutputType,
2896 TosaErrorValidator.evWrongInputList,
2897 TosaErrorValidator.evWrongOutputList,
2898 TosaErrorValidator.evDimensionMismatch,
2899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002900 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002901 "minimum": {
2902 "op": Op.MINIMUM,
2903 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002904 "build_fcn": (
2905 build_binary_broadcast,
2906 TosaTensorGen.tgBroadcastFuzz,
2907 TosaTensorValuesGen.tvgDefault,
2908 None,
2909 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002910 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002911 "error_if_validators": (
2912 TosaErrorValidator.evRankMismatch,
2913 TosaErrorValidator.evWrongInputType,
2914 TosaErrorValidator.evWrongOutputType,
2915 TosaErrorValidator.evWrongInputList,
2916 TosaErrorValidator.evWrongOutputList,
2917 TosaErrorValidator.evDimensionMismatch,
2918 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002919 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002920 "mul": {
2921 "op": Op.MUL,
2922 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002923 "build_fcn": (
2924 build_mul,
2925 TosaTensorGen.tgBroadcastFuzz,
2926 TosaTensorValuesGen.tvgMul,
2927 TosaArgGen.agMul,
2928 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002929 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002930 "error_if_validators": (
2931 TosaErrorValidator.evWrongInputType,
2932 TosaErrorValidator.evWrongOutputType,
2933 TosaErrorValidator.evWrongInputList,
2934 TosaErrorValidator.evWrongOutputList,
2935 TosaErrorValidator.evRankMismatch,
2936 TosaErrorValidator.evDimensionMismatch,
2937 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002938 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002939 "pow": {
2940 "op": Op.POW,
2941 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002942 "build_fcn": (
2943 build_binary_broadcast,
2944 TosaTensorGen.tgBroadcastFuzz,
2945 TosaTensorValuesGen.tvgDefault,
2946 None,
2947 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002948 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002949 "error_if_validators": (
2950 TosaErrorValidator.evRankMismatch,
2951 TosaErrorValidator.evWrongInputType,
2952 TosaErrorValidator.evWrongOutputType,
2953 TosaErrorValidator.evWrongInputList,
2954 TosaErrorValidator.evWrongOutputList,
2955 TosaErrorValidator.evDimensionMismatch,
2956 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002957 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002958 "sub": {
2959 "op": Op.SUB,
2960 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002961 "build_fcn": (
2962 build_binary_broadcast,
2963 TosaTensorGen.tgBroadcastFuzz,
2964 TosaTensorValuesGen.tvgAddSub,
2965 None,
2966 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002967 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002968 "error_if_validators": (
2969 TosaErrorValidator.evRankMismatch,
2970 TosaErrorValidator.evWrongInputType,
2971 TosaErrorValidator.evWrongOutputType,
2972 TosaErrorValidator.evWrongInputList,
2973 TosaErrorValidator.evWrongOutputList,
2974 TosaErrorValidator.evDimensionMismatch,
2975 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002976 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002977 "table": {
2978 "op": Op.TABLE,
2979 # Use the automatic generation functions to create the input array
2980 # but create the table tensor in the build function, as it may be
2981 # a different type from the input
2982 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002983 "build_fcn": (
2984 build_table,
2985 TosaTensorGen.tgBasic,
2986 TosaTensorValuesGen.tvgDefault,
2987 TosaArgGen.agTable,
2988 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002989 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002990 "error_if_validators": (
2991 TosaErrorValidator.evWrongInputType,
2992 TosaErrorValidator.evWrongOutputType,
2993 TosaErrorValidator.evWrongInputList,
2994 TosaErrorValidator.evWrongOutputList,
2995 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002996 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002997 # Elementwise Unary operators
2998 "abs": {
2999 "op": Op.ABS,
3000 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003001 "build_fcn": (
3002 build_unary,
3003 TosaTensorGen.tgBasic,
3004 TosaTensorValuesGen.tvgDefault,
3005 None,
3006 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003007 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003008 "error_if_validators": (
3009 TosaErrorValidator.evWrongInputType,
3010 TosaErrorValidator.evWrongOutputType,
3011 TosaErrorValidator.evWrongInputList,
3012 TosaErrorValidator.evWrongOutputList,
3013 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003015 "bitwise_not": {
3016 "op": Op.BITWISE_NOT,
3017 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003018 "build_fcn": (
3019 build_unary,
3020 TosaTensorGen.tgBasic,
3021 TosaTensorValuesGen.tvgDefault,
3022 None,
3023 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003024 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003025 "error_if_validators": (
3026 TosaErrorValidator.evWrongInputType,
3027 TosaErrorValidator.evWrongOutputType,
3028 TosaErrorValidator.evWrongInputList,
3029 TosaErrorValidator.evWrongOutputList,
3030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003032 "ceil": {
3033 "op": Op.CEIL,
3034 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003035 "build_fcn": (
3036 build_unary,
3037 TosaTensorGen.tgBasic,
3038 TosaTensorValuesGen.tvgDefault,
3039 None,
3040 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003041 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003042 "error_if_validators": (
3043 TosaErrorValidator.evWrongInputType,
3044 TosaErrorValidator.evWrongOutputType,
3045 TosaErrorValidator.evWrongInputList,
3046 TosaErrorValidator.evWrongOutputList,
3047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003048 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003049 "clz": {
3050 "op": Op.CLZ,
3051 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003052 "build_fcn": (
3053 build_unary,
3054 TosaTensorGen.tgBasic,
3055 TosaTensorValuesGen.tvgDefault,
3056 None,
3057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003058 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003059 "error_if_validators": (
3060 TosaErrorValidator.evWrongInputType,
3061 TosaErrorValidator.evWrongOutputType,
3062 TosaErrorValidator.evWrongInputList,
3063 TosaErrorValidator.evWrongOutputList,
3064 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003065 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003066 "exp": {
3067 "op": Op.EXP,
3068 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003069 "build_fcn": (
3070 build_unary,
3071 TosaTensorGen.tgBasic,
3072 TosaTensorValuesGen.tvgDefault,
3073 None,
3074 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003075 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003076 "error_if_validators": (
3077 TosaErrorValidator.evWrongInputType,
3078 TosaErrorValidator.evWrongOutputType,
3079 TosaErrorValidator.evWrongInputList,
3080 TosaErrorValidator.evWrongOutputList,
3081 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003082 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003083 "floor": {
3084 "op": Op.FLOOR,
3085 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003086 "build_fcn": (
3087 build_unary,
3088 TosaTensorGen.tgBasic,
3089 TosaTensorValuesGen.tvgDefault,
3090 None,
3091 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003092 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003093 "error_if_validators": (
3094 TosaErrorValidator.evWrongInputType,
3095 TosaErrorValidator.evWrongOutputType,
3096 TosaErrorValidator.evWrongInputList,
3097 TosaErrorValidator.evWrongOutputList,
3098 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003099 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003100 "log": {
3101 "op": Op.LOG,
3102 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003103 "build_fcn": (
3104 build_unary,
3105 TosaTensorGen.tgBasic,
3106 TosaTensorValuesGen.tvgDefault,
3107 None,
3108 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003109 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003110 "error_if_validators": (
3111 TosaErrorValidator.evWrongInputType,
3112 TosaErrorValidator.evWrongOutputType,
3113 TosaErrorValidator.evWrongInputList,
3114 TosaErrorValidator.evWrongOutputList,
3115 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003117 "logical_not": {
3118 "op": Op.LOGICAL_NOT,
3119 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120 "build_fcn": (
3121 build_unary,
3122 TosaTensorGen.tgBasic,
3123 TosaTensorValuesGen.tvgDefault,
3124 None,
3125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003127 "error_if_validators": (
3128 TosaErrorValidator.evWrongInputType,
3129 TosaErrorValidator.evWrongOutputType,
3130 TosaErrorValidator.evWrongInputList,
3131 TosaErrorValidator.evWrongOutputList,
3132 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003133 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003134 "negate": {
3135 "op": Op.NEGATE,
3136 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003137 "build_fcn": (
3138 build_unary,
3139 TosaTensorGen.tgBasic,
3140 TosaTensorValuesGen.tvgNegate,
3141 None,
3142 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "qgen": TosaQuantGen.qgUnary,
3144 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003145 "error_if_validators": (
3146 TosaErrorValidator.evInputZeroPointNotZero,
3147 TosaErrorValidator.evOutputZeroPointNotZero,
3148 TosaErrorValidator.evWrongInputType,
3149 TosaErrorValidator.evWrongOutputType,
3150 TosaErrorValidator.evWrongInputList,
3151 TosaErrorValidator.evWrongOutputList,
3152 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003153 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003154 "reciprocal": {
3155 "op": Op.RECIPROCAL,
3156 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003157 "build_fcn": (
3158 build_unary,
3159 TosaTensorGen.tgBasic,
3160 TosaTensorValuesGen.tvgDefault,
3161 None,
3162 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003163 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003164 "error_if_validators": (
3165 TosaErrorValidator.evWrongInputType,
3166 TosaErrorValidator.evWrongOutputType,
3167 TosaErrorValidator.evWrongInputList,
3168 TosaErrorValidator.evWrongOutputList,
3169 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003170 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003171 "rsqrt": {
3172 "op": Op.RSQRT,
3173 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003174 "build_fcn": (
3175 build_unary,
3176 TosaTensorGen.tgBasic,
3177 TosaTensorValuesGen.tvgDefault,
3178 None,
3179 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003180 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003181 "error_if_validators": (
3182 TosaErrorValidator.evWrongInputType,
3183 TosaErrorValidator.evWrongOutputType,
3184 TosaErrorValidator.evWrongInputList,
3185 TosaErrorValidator.evWrongOutputList,
3186 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003187 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003188 # Elementwise Ternary operators
3189 "select": {
3190 "op": Op.SELECT,
3191 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003192 "build_fcn": (
3193 build_select,
3194 TosaTensorGen.tgBroadcastFuzz,
3195 TosaTensorValuesGen.tvgSelect,
3196 None,
3197 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003198 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003199 "error_if_validators": (
3200 TosaErrorValidator.evRankMismatch,
3201 TosaErrorValidator.evWrongInputType,
3202 TosaErrorValidator.evWrongOutputType,
3203 TosaErrorValidator.evWrongInputList,
3204 TosaErrorValidator.evWrongOutputList,
3205 TosaErrorValidator.evDimensionMismatch,
3206 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003207 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003208 # Comparison operators
3209 "equal": {
3210 "op": Op.EQUAL,
3211 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003212 "build_fcn": (
3213 build_comparison,
3214 TosaTensorGen.tgBroadcastFuzz,
3215 TosaTensorValuesGen.tvgEqual,
3216 None,
3217 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003218 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003219 "error_if_validators": (
3220 TosaErrorValidator.evRankMismatch,
3221 TosaErrorValidator.evWrongInputType,
3222 TosaErrorValidator.evWrongOutputType,
3223 TosaErrorValidator.evWrongInputList,
3224 TosaErrorValidator.evWrongOutputList,
3225 TosaErrorValidator.evDimensionMismatch,
3226 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003227 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003228 "greater_equal": {
3229 "op": Op.GREATER_EQUAL,
3230 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003231 "build_fcn": (
3232 build_comparison,
3233 TosaTensorGen.tgBroadcastFuzz,
3234 TosaTensorValuesGen.tvgDefault,
3235 None,
3236 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003237 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003238 "error_if_validators": (
3239 TosaErrorValidator.evRankMismatch,
3240 TosaErrorValidator.evWrongInputType,
3241 TosaErrorValidator.evWrongOutputType,
3242 TosaErrorValidator.evWrongInputList,
3243 TosaErrorValidator.evWrongOutputList,
3244 TosaErrorValidator.evDimensionMismatch,
3245 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003247 "greater": {
3248 "op": Op.GREATER,
3249 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003250 "build_fcn": (
3251 build_comparison,
3252 TosaTensorGen.tgBroadcastFuzz,
3253 TosaTensorValuesGen.tvgDefault,
3254 None,
3255 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003256 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003257 "error_if_validators": (
3258 TosaErrorValidator.evRankMismatch,
3259 TosaErrorValidator.evWrongInputType,
3260 TosaErrorValidator.evWrongOutputType,
3261 TosaErrorValidator.evWrongInputList,
3262 TosaErrorValidator.evWrongOutputList,
3263 TosaErrorValidator.evDimensionMismatch,
3264 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003266 # Reduction operators
3267 "reduce_all": {
3268 "op": Op.REDUCE_ALL,
3269 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003270 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003271 "build_fcn": (
3272 build_reduce,
3273 TosaTensorGen.tgBasic,
3274 TosaTensorValuesGen.tvgDefault,
3275 TosaArgGen.agAxis,
3276 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003278 "error_if_validators": (
3279 TosaErrorValidator.evAxisLargerRank,
3280 TosaErrorValidator.evAxisSmallerZero,
3281 TosaErrorValidator.evShapeOfAxisNotOne,
3282 TosaErrorValidator.evWrongInputType,
3283 TosaErrorValidator.evWrongOutputType,
3284 TosaErrorValidator.evWrongRank,
3285 TosaErrorValidator.evWrongInputList,
3286 TosaErrorValidator.evWrongOutputList,
3287 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003288 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003289 "reduce_any": {
3290 "op": Op.REDUCE_ANY,
3291 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003292 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003293 "build_fcn": (
3294 build_reduce,
3295 TosaTensorGen.tgBasic,
3296 TosaTensorValuesGen.tvgDefault,
3297 TosaArgGen.agAxis,
3298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003300 "error_if_validators": (
3301 TosaErrorValidator.evAxisLargerRank,
3302 TosaErrorValidator.evAxisSmallerZero,
3303 TosaErrorValidator.evShapeOfAxisNotOne,
3304 TosaErrorValidator.evWrongInputType,
3305 TosaErrorValidator.evWrongOutputType,
3306 TosaErrorValidator.evWrongRank,
3307 TosaErrorValidator.evWrongInputList,
3308 TosaErrorValidator.evWrongOutputList,
3309 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 "reduce_max": {
3312 "op": Op.REDUCE_MAX,
3313 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003314 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 "build_fcn": (
3316 build_reduce,
3317 TosaTensorGen.tgBasic,
3318 TosaTensorValuesGen.tvgDefault,
3319 TosaArgGen.agAxis,
3320 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003322 "error_if_validators": (
3323 TosaErrorValidator.evAxisLargerRank,
3324 TosaErrorValidator.evAxisSmallerZero,
3325 TosaErrorValidator.evShapeOfAxisNotOne,
3326 TosaErrorValidator.evWrongInputType,
3327 TosaErrorValidator.evWrongOutputType,
3328 TosaErrorValidator.evWrongRank,
3329 TosaErrorValidator.evWrongInputList,
3330 TosaErrorValidator.evWrongOutputList,
3331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003334 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003335 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003336 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 "build_fcn": (
3338 build_reduce,
3339 TosaTensorGen.tgBasic,
3340 TosaTensorValuesGen.tvgDefault,
3341 TosaArgGen.agAxis,
3342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003344 "error_if_validators": (
3345 TosaErrorValidator.evAxisLargerRank,
3346 TosaErrorValidator.evAxisSmallerZero,
3347 TosaErrorValidator.evShapeOfAxisNotOne,
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongRank,
3351 TosaErrorValidator.evWrongInputList,
3352 TosaErrorValidator.evWrongOutputList,
3353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "reduce_product": {
3356 "op": Op.REDUCE_PRODUCT,
3357 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003358 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003359 "build_fcn": (
3360 build_reduce,
3361 TosaTensorGen.tgBasic,
3362 TosaTensorValuesGen.tvgDefault,
3363 TosaArgGen.agAxis,
3364 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003365 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003366 "error_if_validators": (
3367 TosaErrorValidator.evAxisLargerRank,
3368 TosaErrorValidator.evAxisSmallerZero,
3369 TosaErrorValidator.evShapeOfAxisNotOne,
3370 TosaErrorValidator.evWrongInputType,
3371 TosaErrorValidator.evWrongOutputType,
3372 TosaErrorValidator.evWrongRank,
3373 TosaErrorValidator.evWrongInputList,
3374 TosaErrorValidator.evWrongOutputList,
3375 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "reduce_sum": {
3378 "op": Op.REDUCE_SUM,
3379 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003380 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 "build_fcn": (
3382 build_reduce,
3383 TosaTensorGen.tgBasic,
3384 TosaTensorValuesGen.tvgReduceSum,
3385 TosaArgGen.agAxis,
3386 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003387 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003388 "error_if_validators": (
3389 TosaErrorValidator.evAxisLargerRank,
3390 TosaErrorValidator.evAxisSmallerZero,
3391 TosaErrorValidator.evShapeOfAxisNotOne,
3392 TosaErrorValidator.evWrongInputType,
3393 TosaErrorValidator.evWrongOutputType,
3394 TosaErrorValidator.evWrongRank,
3395 TosaErrorValidator.evWrongInputList,
3396 TosaErrorValidator.evWrongOutputList,
3397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003399 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003400 "concat": {
3401 "op": Op.CONCAT,
3402 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003403 "build_fcn": (
3404 build_concat,
3405 TosaTensorGen.tgConcat,
3406 TosaTensorValuesGen.tvgConcat,
3407 TosaArgGen.agAxis,
3408 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003409 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003410 "error_if_validators": (
3411 TosaErrorValidator.evAxisLargerRank,
3412 TosaErrorValidator.evAxisSmallerZero,
3413 TosaErrorValidator.evConcatInputRankMismatch,
3414 TosaErrorValidator.evConcatShapeSumMismatch,
3415 TosaErrorValidator.evConcatInputDimMismatch,
3416 TosaErrorValidator.evWrongInputType,
3417 TosaErrorValidator.evWrongOutputType,
3418 TosaErrorValidator.evWrongOutputList,
3419 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003420 },
3421 "pad": {
3422 "op": Op.PAD,
3423 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003424 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003425 "build_fcn": (
3426 build_pad,
3427 TosaTensorGen.tgBasic,
3428 TosaTensorValuesGen.tvgDefault,
3429 TosaArgGen.agPad,
3430 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003431 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003432 "error_if_validators": (
3433 TosaErrorValidator.evWrongInputType,
3434 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003435 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003436 TosaErrorValidator.evWrongOutputType,
3437 TosaErrorValidator.evWrongInputList,
3438 TosaErrorValidator.evWrongOutputList,
3439 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003440 },
3441 "reshape": {
3442 "op": Op.RESHAPE,
3443 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003444 "build_fcn": (
3445 build_reshape,
3446 TosaTensorGen.tgBasic,
3447 TosaTensorValuesGen.tvgDefault,
3448 TosaArgGen.agReshape,
3449 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003450 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 "error_if_validators": (
3452 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3453 TosaErrorValidator.evWrongInputType,
3454 TosaErrorValidator.evWrongOutputType,
3455 TosaErrorValidator.evWrongInputList,
3456 TosaErrorValidator.evWrongOutputList,
3457 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003458 },
3459 "reverse": {
3460 "op": Op.REVERSE,
3461 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 "build_fcn": (
3463 build_reverse,
3464 TosaTensorGen.tgBasic,
3465 TosaTensorValuesGen.tvgDefault,
3466 TosaArgGen.agAxis,
3467 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003468 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003469 "error_if_validators": (
3470 TosaErrorValidator.evAxisSmallerZero,
3471 TosaErrorValidator.evAxisLargerRank,
3472 TosaErrorValidator.evWrongInputType,
3473 TosaErrorValidator.evWrongOutputType,
3474 TosaErrorValidator.evWrongInputList,
3475 TosaErrorValidator.evWrongOutputList,
3476 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003477 },
3478 "slice": {
3479 "op": Op.SLICE,
3480 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003481 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003482 "build_fcn": (
3483 build_slice,
3484 TosaTensorGen.tgBasic,
3485 TosaTensorValuesGen.tvgDefault,
3486 TosaArgGen.agSlice,
3487 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003488 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003489 "error_if_validators": (
3490 TosaErrorValidator.evStartSmallerZero,
3491 TosaErrorValidator.evSizeSmallerEqualZero,
3492 TosaErrorValidator.evStartSizeOutsideBounds,
3493 TosaErrorValidator.evSizeOutputShapeMismatch,
3494 TosaErrorValidator.evInputSizeStartLengthMismatch,
3495 TosaErrorValidator.evWrongRank,
3496 TosaErrorValidator.evWrongInputType,
3497 TosaErrorValidator.evWrongOutputType,
3498 TosaErrorValidator.evWrongInputList,
3499 TosaErrorValidator.evWrongOutputList,
3500 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003501 },
3502 "tile": {
3503 "op": Op.TILE,
3504 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003505 "build_fcn": (
3506 build_tile,
3507 TosaTensorGen.tgBasic,
3508 TosaTensorValuesGen.tvgDefault,
3509 TosaArgGen.agTile,
3510 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003511 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003512 "error_if_validators": (
3513 TosaErrorValidator.evWrongInputType,
3514 TosaErrorValidator.evWrongOutputType,
3515 TosaErrorValidator.evWrongInputList,
3516 TosaErrorValidator.evWrongOutputList,
3517 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003518 },
3519 "transpose": {
3520 "op": Op.TRANSPOSE,
3521 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003522 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003523 "build_fcn": (
3524 build_transpose,
3525 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003526 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003527 TosaArgGen.agTranspose,
3528 ),
3529 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 "error_if_validators": (
3531 TosaErrorValidator.evIndexOutsideBounds,
3532 TosaErrorValidator.evIndexUsedTwice,
3533 TosaErrorValidator.evWrongInputType,
3534 TosaErrorValidator.evWrongOutputType,
3535 TosaErrorValidator.evWrongInputList,
3536 TosaErrorValidator.evWrongOutputList,
3537 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003538 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003539 # Data nodes
3540 "const": {
3541 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003542 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003543 "build_fcn": (
3544 build_const,
3545 TosaTensorGen.tgBasic,
3546 TosaTensorValuesGen.tvgDefault,
3547 None,
3548 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "types": TYPE_FIB,
3550 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "identity": {
3552 "op": Op.IDENTITY,
3553 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003554 "build_fcn": (
3555 build_unary,
3556 TosaTensorGen.tgBasic,
3557 TosaTensorValuesGen.tvgDefault,
3558 None,
3559 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003560 "types": TYPE_FIB,
3561 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003562 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003563 "gather": {
3564 "op": Op.GATHER,
3565 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3566 "operands": (1, 0),
3567 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003568 "build_fcn": (
3569 build_gather,
3570 TosaTensorGen.tgBasic,
3571 TosaTensorValuesGen.tvgDefault,
3572 None,
3573 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003574 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003575 "error_if_validators": (
3576 TosaErrorValidator.evWrongInputType,
3577 TosaErrorValidator.evWrongOutputType,
3578 TosaErrorValidator.evWrongInputList,
3579 TosaErrorValidator.evWrongOutputList,
3580 TosaErrorValidator.evWrongRank,
3581 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 },
3583 "scatter": {
3584 "op": Op.SCATTER,
3585 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003586 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003587 "operands": (2, 0),
3588 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003589 "build_fcn": (
3590 build_scatter,
3591 TosaTensorGen.tgScatter,
3592 TosaTensorValuesGen.tvgDefault,
3593 None,
3594 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003596 "error_if_validators": (
3597 TosaErrorValidator.evWrongInputType,
3598 TosaErrorValidator.evWrongOutputType,
3599 TosaErrorValidator.evWrongInputList,
3600 TosaErrorValidator.evWrongOutputList,
3601 TosaErrorValidator.evWrongRank,
3602 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003603 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003604 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003605 "resize": {
3606 "op": Op.RESIZE,
3607 "operands": (1, 0),
3608 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003609 "build_fcn": (
3610 build_resize,
3611 TosaTensorGen.tgNHWC,
3612 TosaTensorValuesGen.tvgDefault,
3613 TosaArgGen.agResize,
3614 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003615 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 "invalid_test_validators": (
3617 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003618 ),
3619 "error_if_validators": (
3620 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003621 TosaErrorValidator.evScaleSmallerEqualZero,
3622 TosaErrorValidator.evScaleNLargerMax,
3623 TosaErrorValidator.evScaleDLargerMax,
3624 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003625 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003626 TosaErrorValidator.evBorderSmallerMin,
3627 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003628 TosaErrorValidator.evWrongInputType,
3629 TosaErrorValidator.evWrongOutputType,
3630 TosaErrorValidator.evWrongRank,
3631 TosaErrorValidator.evWrongInputList,
3632 TosaErrorValidator.evWrongOutputList,
3633 TosaErrorValidator.evBatchMismatch,
3634 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003635 TosaErrorValidator.evResizeOutputShapeMismatch,
3636 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003637 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003639 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003640 "cast": {
3641 "op": Op.CAST,
3642 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 "build_fcn": (
3644 build_cast,
3645 TosaTensorGen.tgBasic,
3646 TosaTensorValuesGen.tvgDefault,
3647 TosaArgGen.agCast,
3648 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003649 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003650 "error_if_validators": (
3651 TosaErrorValidator.evWrongInputType,
3652 TosaErrorValidator.evWrongOutputType,
3653 TosaErrorValidator.evWrongInputList,
3654 TosaErrorValidator.evWrongOutputList,
3655 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003656 },
3657 "rescale": {
3658 "op": Op.RESCALE,
3659 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 "build_fcn": (
3661 build_rescale,
3662 TosaTensorGen.tgBasic,
3663 TosaTensorValuesGen.tvgDefault,
3664 TosaArgGen.agRescale,
3665 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003666 "types": [
3667 DType.UINT8,
3668 DType.INT8,
3669 DType.INT16,
3670 DType.INT32,
3671 DType.INT48,
3672 DType.UINT16,
3673 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003674 "error_if_validators": (
3675 TosaErrorValidator.evInputZeroPointNotZero,
3676 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003677 TosaErrorValidator.evU16InputZeroPointNotValid,
3678 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 TosaErrorValidator.evScaleTrue,
3680 TosaErrorValidator.evScaleNotTrue,
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003686 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003687 # Custom
3688 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003690 # Two varients of cond_if, one that generates one of two constant tensors (no
3691 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3692 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003693 "cond_if_const": {
3694 "op": Op.COND_IF,
3695 "operands": (0, 2),
3696 "build_fcn": (
3697 build_cond_if_const,
3698 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003700 TosaArgGen.agCondIf,
3701 ),
3702 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003703 "error_if_validators": (
3704 TosaErrorValidator.evOutputListThenGraphMismatch,
3705 TosaErrorValidator.evOutputListElseGraphMismatch,
3706 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003707 },
3708 "cond_if_binary": {
3709 "op": Op.COND_IF,
3710 "operands": (2, 0),
3711 "build_fcn": (
3712 build_cond_if_binary,
3713 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003715 TosaArgGen.agCondIf,
3716 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003717 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003718 "error_if_validators": (
3719 TosaErrorValidator.evInputListThenGraphMismatch,
3720 TosaErrorValidator.evInputListElseGraphMismatch,
3721 TosaErrorValidator.evOutputListThenGraphMismatch,
3722 TosaErrorValidator.evOutputListElseGraphMismatch,
3723 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003724 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003725 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003726 "while_loop": {
3727 "op": Op.WHILE_LOOP,
3728 "operands": (0, 1),
3729 "build_fcn": (
3730 build_while_loop,
3731 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003733 TosaArgGen.agWhileLoop,
3734 ),
3735 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 "error_if_validators": (
3737 TosaErrorValidator.evInputListOutputListMismatch,
3738 TosaErrorValidator.evInputListCondGraphMismatch,
3739 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3740 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3741 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3742 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003743 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003744 }
3745
Kevin Cheng550ccc52021-03-03 11:21:43 -08003746
Eric Kunzee5e26762020-10-13 16:11:07 -07003747class OutputShaper:
3748 # Methods in this class compute the expected output shape and datatype
3749 # for common classes of operations
def __init__(self):
    """Create a shaper; it carries no per-instance state to initialise."""
3752
3753 # These methods return arguments that can be used for
3754 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of an element-wise binary op that broadcasts.

    For a valid test (``error_name is None``) every size-1 dimension of
    ``a`` takes the matching extent of ``b``; for ERROR_IF tests the shape
    of ``a`` is used unchanged.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only used to pick a wrong output dtype.
        a: first input tensor (must expose ``shape`` and ``dtype``).
        b: second input tensor (must expose ``shape`` and ``dtype``).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    # RankMismatch tests intentionally pair tensors of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    # Any dtype other than the input's makes the output type wrong.
    dtype_pool = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    )
    wrong_choice = rng.choice(list(set(dtype_pool) - {a.dtype}))
    return ser.addOutput(out_shape, wrong_choice)
Eric Kunzee5e26762020-10-13 16:11:07 -07003782
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor of a binary op whose inputs must match exactly.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        a: first input tensor (must expose ``shape`` and ``dtype``).
        b: second input tensor; same rank, extents and dtype as ``a``.

    Returns:
        The result of ``ser.addOutput`` called with a list copy of
        ``a.shape`` and ``a.dtype``.

    Raises:
        AssertionError: if the ranks, any extent, or the dtypes differ.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003794
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor of an element-wise unary op.

    The output shape equals ``a.shape``. The output dtype equals
    ``a.dtype``, unless a ``WrongOutputType`` test asks for some other one.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only used to pick a wrong output dtype.
        a: input tensor (must expose ``shape`` and ``dtype``).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003811
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor of SELECT(cond, a, b).

    The output follows ``cond``'s shape. For a valid test, each size-1
    dimension of ``cond`` is widened to the largest extent that ``cond``,
    ``a`` and ``b`` have in that dimension.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only used to pick a wrong output dtype.
        cond: condition tensor (must expose ``shape``).
        a: "true" input tensor (must expose ``shape`` and ``dtype``).
        b: "false" input tensor with the same dtype as ``a``.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    # RankMismatch tests intentionally pair tensors of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        widen = cond_dim == 1 and error_name is None
        out_shape.append(
            max(cond_dim, a.shape[idx], b.shape[idx]) if widen else cond_dim
        )

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003839
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the BOOL output tensor of an element-wise comparison op.

    Shapes broadcast: a size-1 dimension of ``a`` takes ``b``'s extent
    whenever ``b`` has that dimension. This applies even for ERROR_IF tests.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only used to pick a wrong output dtype.
        a: first input tensor (must expose ``shape`` and ``dtype``).
        b: second input tensor with the same dtype as ``a``.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    # RankMismatch tests intentionally pair tensors of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (a_dim == 1 and idx < b_rank) else a_dim
        for idx, a_dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # The candidate list does not contain DType.BOOL, so every pick is wrong.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003867
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a reduction along ``axis``.

    Normally ``axis`` collapses to size 1. For the axis-related ERROR_IF
    tests the input extent is kept. For ``ShapeOfAxisNotOne``, an input
    extent of 1 is replaced by a random value from ``rng.integers(2, 10)``.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        a: input tensor (``a.shape`` must support ``.copy()``).
        axis: index of the reduced dimension.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    keeps_axis = error_name in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if not keeps_axis:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003894
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of ARGMAX over ``axis``.

    The reduced axis is removed from the shape, except for the invalid-axis
    tests. The rank- and shape-mismatch tests then perturb the result, and
    the output dtype is INT32 unless a ``WrongOutputType`` test picks
    another one.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        a: input tensor (``a.shape`` must support ``.copy()``).
        axis: index of the reduced dimension.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop the leading dim (when more than one remains) or add one.
        drop_first = rng.choice([True, False])
        if drop_first and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003926
@staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor of CONV2D.

    Layouts: IFM is NHWC, the filter is OHWI and the OFM is NHWC. Each
    spatial extent is ``(in - 1 + pad_a + pad_b - (k - 1) * dilation)
    // stride + 1``.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        ifm: input feature map tensor (``shape`` and ``dtype``).
        filter: weight tensor (``shape``).
        strides: (stride_y, stride_x).
        padding: (top, bottom, left, right) amounts.
        dilations: (dilation_y, dilation_x).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``ifm.dtype`` has no output mapping and this is not a
            ``WrongInputType`` test.
    """

    def out_extent(in_size, pad_a, pad_b, kernel, dilation, stride):
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 grows the height, 2 the width, 3 both. Growing by whole strides
        # keeps clear of the non-integer output error case.
        target = rng.choice(steps)
        if target in (1, 3):
            h += rng.choice(steps) * strides[0]
        if target in (2, 3):
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # A plausible output type is enough when the input type is the error.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003978
@staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor of CONV3D.

    Layouts: IFM is NDHWC, the filter is ODHWI and the OFM is NDHWC. Each
    spatial extent is ``(in - 1 + pad_a + pad_b - (k - 1) * dilation)
    // stride + 1``.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        ifm: input feature map tensor (``shape`` and ``dtype``).
        filter: weight tensor (``shape``).
        strides: (stride_d, stride_y, stride_x).
        padding: (d_before, d_after, top, bottom, left, right) amounts.
        dilations: (dilation_d, dilation_y, dilation_x).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``ifm.dtype`` has no output mapping and this is not a
            ``WrongInputType`` test.
    """

    def out_extent(in_size, pad_a, pad_b, kernel, dilation, stride):
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    d = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    h = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    w = out_extent(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1 grows depth, 2 height, 3 width, 4 all three. Growing by whole
        # strides keeps clear of the non-integer output error case.
        target = rng.choice(steps)
        if target in (1, 4):
            d += rng.choice(steps) * strides[0]
        if target in (2, 4):
            h += rng.choice(steps) * strides[1]
        if target in (3, 4):
            w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # A plausible output type is enough when the input type is the error.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
4040
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, strides, padding, dilations, error_name=None
):
    """Create the output tensor of DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M). Each
    spatial extent is ``(in - 1 + pad_a + pad_b - (k - 1) * dilation)
    // stride + 1``.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        ifm: input feature map tensor (``shape`` and ``dtype``).
        filter: weight tensor (``shape``).
        strides: (stride_y, stride_x).
        padding: (top, bottom, left, right) amounts.
        dilations: (dilation_y, dilation_x).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``ifm.dtype`` has no output mapping and this is not a
            ``WrongInputType`` test.
    """

    def out_extent(in_size, pad_a, pad_b, kernel, dilation, stride):
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 grows the height, 2 the width, 3 both. Growing by whole strides
        # keeps clear of the non-integer output error case.
        target = rng.choice(steps)
        if target in (1, 3):
            h += rng.choice(steps) * strides[0]
        if target in (2, 3):
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # A plausible output type is enough when the input type is the error.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004093
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor of a 2D pooling op on an NHWC input.

    Each spatial extent is ``(in + pad_a + pad_b - k) // stride + 1``. With
    a non-positive stride or negative padding the test is invalid anyway,
    so the spatial extents fall back to 1.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        ifm: input tensor (``shape`` and ``dtype``).
        kernel: (kernel_y, kernel_x).
        stride: (stride_y, stride_x).
        pad: (top, bottom, left, right) amounts.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 grows the height, 2 the width, 3 both. Growing by whole strides
        # keeps clear of the non-integer output error case.
        target = rng.choice(steps)
        if target in (1, 3):
            h += rng.choice(steps) * stride[0]
        if target in (2, 3):
            w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004129
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Create the output tensor of FULLY_CONNECTED.

    Shapes: input is [N, IC], filter is [OC, IC], output is [N, OC].
    The output dtype follows the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT). A ``WrongOutputType`` test replaces
    it with any other candidate type.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only used to pick a wrong output dtype.
        input: input tensor (``shape`` and ``dtype``).
        filter: weight tensor (``shape``).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``input.dtype`` has no output mapping and this is not
            a ``WrongInputType`` test (``WrongOutputType`` tests included).
    """
    output_shape = [input.shape[0], filter.shape[0]]

    # Resolve the correct output dtype first, so every error path below
    # starts from a bound value.
    if input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    if error_name == ErrorIf.WrongOutputType:
        # Every candidate except the correct type. The fixed order keeps
        # rng picks reproducible.
        candidates = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in candidates if t != out_dtype)
        out_dtype = rng.choice(a=incorrect_types)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004177
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Create the output tensor of MATMUL.

    Shapes: a is [N, H, C], b is [N, C, W], output is [N, H, W].
    The output dtype follows the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT). A ``WrongOutputType`` test replaces
    it with any other candidate type.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only used to pick a wrong output dtype.
        a: left-hand input tensor (``shape`` and ``dtype``).
        b: right-hand input tensor (``shape``).
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``a.dtype`` has no output mapping and this is not a
            ``WrongInputType`` test (``WrongOutputType`` tests included).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    # Resolve the correct output dtype first, so every error path below
    # starts from a bound value.
    if a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    if error_name == ErrorIf.WrongOutputType:
        # Every candidate except the correct type. The fixed order keeps
        # rng picks reproducible.
        candidates = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in candidates if t != out_dtype)
        out_dtype = rng.choice(a=incorrect_types)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004225
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Create the output tensor of CONCAT along ``axis``.

    The first input's shape is used, with the ``axis`` extents of all
    remaining inputs added on. The sum is skipped when ranks differ or the
    axis is invalid, since it cannot be computed in those cases.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        axis: concatenation axis.
        *a: input tensors; ``a[0].shape`` must support ``.copy()``.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    first = a[0]
    others = a[1:]

    out_shape = first.shape.copy()
    can_sum = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for tensor in others:
            out_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        }
        out_dtype = rng.choice(list(dtype_pool - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004259
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor of PAD.

    Each output extent is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    The output-shape-mismatch test shrinks one random dimension by 1 or 2.
    The PadSmallerZero test clamps any non-positive extent to 1.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator used for the negative tests.
        a: input tensor (``a.shape`` must support ``.copy()``).
        padding: per-dimension (before, after) pad amounts.
        error_name: optional ``ErrorIf`` value naming the negative test.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(out_shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        out_shape[bad_dim] -= rng.choice([1, 2])

    # Negative padding can push extents below 1; keep the shape usable.
    if error_name == ErrorIf.PadSmallerZero and min(out_shape) < 1:
        out_shape = [max(dim, 1) for dim in out_shape]

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004290
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for a RESHAPE operation.

    Normally the output takes a copy of the requested ``shape`` and the dtype
    of ``a``; ``error_name`` may request a deliberate ERROR_IF perturbation.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``integers``/``choice`` API).
        a: input tensor; only its ``dtype`` is read here.
        shape: requested output shape (copied, never modified).
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Enlarge every dimension by a random 1..9 so the element count no
        # longer equals that of the requested shape.
        for idx, dim in enumerate(new_shape):
            new_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        dtypes = (DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        # Set difference fixes the candidate order the rng picks from.
        out_dtype = rng.choice(list(set(dtypes) - set([a.dtype])))

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004313
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the output tensor for a SLICE operation.

    The output shape is a copy of ``size`` and the dtype follows ``a``;
    ``start`` is not used here. ``error_name`` may request a deliberate
    ERROR_IF perturbation of the dtype or of every output dimension.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``choice`` API).
        a: input tensor; only its ``dtype`` is read here.
        start: slice start (unused by this shaper).
        size: slice size, used as the output shape (copied).
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        dtypes = (DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        # Set difference fixes the candidate order the rng picks from.
        out_dtype = rng.choice(list(set(dtypes) - set([a.dtype])))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(out_shape):
            # Extents <= 2 only ever grow; larger ones move by +/-1 or +/-2.
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = extent + rng.choice(deltas)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004343
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for a TILE operation.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``; the dtype follows
    ``a`` unless ``error_name`` asks for a wrong output type.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``choice`` API).
        a: input tensor; its ``shape`` and ``dtype`` are read.
        multiples: per-dimension repeat counts (same length as ``a.shape``).
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    assert len(multiples) == len(a.shape)
    tiled_shape = [dim * multiples[i] for i, dim in enumerate(a.shape)]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        dtypes = (DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        # Set difference fixes the candidate order the rng picks from.
        out_dtype = rng.choice(list(set(dtypes) - set([a.dtype])))

    return ser.addOutput(tiled_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004367
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE operation.

    Output dimension ``i`` is ``a.shape[perms[i]]``; for
    ``ErrorIf.IndexOutsideBounds`` every dimension is ``a.shape[0]`` instead
    and ``perms`` is not used for indexing. The dtype follows ``a`` unless a
    wrong output type is requested.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``choice`` API).
        a: input tensor; its ``shape`` and ``dtype`` are read.
        perms: permutation of the input dimensions (same length as the shape).
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    assert len(perms) == len(a.shape)

    rank = len(a.shape)
    if error_name == ErrorIf.IndexOutsideBounds:
        # Presumably perms holds out-of-range entries here - so avoid it.
        out_shape = [a.shape[0] for _ in range(rank)]
    else:
        out_shape = [a.shape[perms[i]] for i in range(rank)]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        dtypes = (DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        # Set difference fixes the candidate order the rng picks from.
        out_dtype = rng.choice(list(set(dtypes) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004395
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER operation.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``
    and the dtype follows ``values`` unless a wrong output type is requested.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``choice`` API).
        values: rank-3 values tensor (rank not checked for WrongRank cases).
        indices: rank-2 indices tensor sharing dimension 0 with ``values``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        # Only the values rank check is relaxed for wrong-rank tests.
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    n_dim = values.shape[0]
    w_dim = indices.shape[1]
    c_dim = values.shape[2]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values.dtype
    else:
        dtypes = (DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        # Set difference fixes the candidate order the rng picks from.
        out_dtype = rng.choice(list(set(dtypes) - set([values.dtype])))

    return ser.addOutput([n_dim, w_dim, c_dim], out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004419
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER operation.

    The output takes the shape of ``values_in`` (as a copy) and its dtype,
    unless ``error_name`` asks for a wrong output type.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``choice`` API).
        values_in: rank-3 tensor ``[N, K, C]``-style; rank not checked for
            WrongRank cases.
        indices: rank-2 tensor; dim 0 matches ``values_in``, dim 1 ``input``.
        input: rank-3 tensor; dim 2 matches ``values_in``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy rather than alias values_in.shape (as the other shapers do) so the
    # output tensor never shares a mutable shape list with an input tensor.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004446
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE operation.

    The output has the input's shape; its dtype is INT32 for an INT16 input
    and INT8 otherwise, unless ``error_name`` asks for a wrong output type.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` registers the tensor.
        rng: random generator (``choice`` API).
        input: input tensor, expected INT8 or INT16 (not checked for
            WrongInputType cases).
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        dtypes = (DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
        # Every listed dtype except the correct one, in list order.
        candidates = [dt for dt in dtypes if dt != output_dtype]
        output_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004464
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE operation.

    Output height/width come from the input H/W (``input.shape[1]``,
    ``input.shape[2]``), ``scale`` (``[y_n, y_d, x_n, x_d]``), ``offset`` and
    ``border`` (``[y, x]`` each)::

        OH = ((IH - 1) * y_n - offset[0] + border[0]) // y_d + 1
        OW = ((IW - 1) * x_n - offset[1] + border[1]) // x_d + 1

    Batch and channels are taken from ``input.shape[0]`` / ``input.shape[3]``.
    ``mode`` and ``input_dtype`` are unused here; the tensor gets
    ``output_dtype``. ``error_name`` may request ERROR_IF perturbations.

    Returns:
        Whatever ``serializer.addOutput`` returns for the new output tensor.
    """
    y_num = scale[0]
    y_den = scale[1]
    x_num = scale[2]
    x_den = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the size computation below divides by a positive value.
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Scale/offset/border altered for ERROR_IF cases can yield a degenerate
        # size; keep the output tensor itself valid.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            cap = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, cap), min(ow, cap)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        change = rng.choice([1, 2, 3])
        # Step by the scale denominators so the perturbation does not also
        # trigger the non-integer error case.
        if change in (1, 3):
            if oh + y_den < MAX_RESIZE_DIMENSION:
                oh += y_den
            else:
                oh -= y_den
                assert oh > 0  # Should have been caught in agResize
        if change in (2, 3):
            if ow + x_den < MAX_RESIZE_DIMENSION:
                ow += x_den
            else:
                ow -= x_den
                assert ow > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): last entry reuses input.shape[0], presumably because a
        # wrong-rank input may lack a shape[3] - confirm.
        output_dims = [batch, oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004543
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create the output tensor for a type conversion.

    The output keeps ``val``'s shape and takes ``out_dtype``. ``rng`` and
    ``error_name`` are accepted (presumably to match the other shapers'
    signatures) but are not used.
    """
    converted_shape = val.shape
    return ser.addOutput(converted_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004547
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operation.

    The output dtype is derived from the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT); for ``ErrorIf.WrongInputType`` with an
    unsupported input, INT32 is used. ``error_name`` may also request a wrong
    output dtype or an output-shape mismatch.

    NOTE: for ``ErrorIf.ConvOutputShapeMismatch`` the caller's ``output_shape``
    list is modified in place (dims 1 and/or 2 are enlarged by 1..3), so the
    caller observes the perturbed shape as well.

    Raises:
        Exception: the input dtype is unsupported and ``error_name`` is not
            ``ErrorIf.WrongInputType``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in (1, 3):
            output_shape[1] += rng.choice(steps)
        if which in (2, 3):
            output_shape[2] += rng.choice(steps)

    dtype_for_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in dtype_for_input:
        out_dtype = dtype_for_input[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(output_shape, out_dtype)