blob: 7c2b9de2ecb7d8f2de22e26d165ea8fe7ea54029 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
16from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000017from tosa.DType import DType
18from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010019
20
Eric Kunzee5e26762020-10-13 16:11:07 -070021class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010022 # Maximum rank of tensor supported by test generator.
23 TOSA_TENSOR_MAX_RANK = 6
24
    def __init__(self, args):
        """Create a test generator configured from the parsed command-line ``args``.

        Reads ``args.output_dir`` and ``args.random_seed`` here; ``args`` is
        kept on ``self.args`` for later use (e.g. ``makeShape`` reads
        ``args.tensor_shape_range``).
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # No serializer exists until createSerializer() is called for a test.
        self.ser = None
        # Seeded generator so generation is reproducible for a given seed.
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): both helpers are defined outside the visible section;
        # keep this call order in case they depend on the attributes above.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
36
37 def createSerializer(self, opName, testPath):
38 self.testPath = os.path.join(opName, testPath)
39
40 fullPath = os.path.join(self.basePath, self.testPath)
41 os.makedirs(fullPath, exist_ok=True)
42 self.ser = ts.TosaSerializer(fullPath)
43
44 def getSerializer(self):
45 return self.ser
46
47 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080048 with open(
49 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
50 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070051 fd.write(self.ser.serialize())
52
Kevin Cheng550ccc52021-03-03 11:21:43 -080053 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
54 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070055
Matthew Haddon74567092021-07-16 15:38:20 +010056 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000057 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010058 seed = self.random_seed + 1
59 self.rng = np.random.default_rng(seed)
60
Eric Kunzee5e26762020-10-13 16:11:07 -070061 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070062 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070063 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070064 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070065 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070066 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070067 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010068 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
69 elif dtype == DType.UINT8:
70 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070071 elif dtype == DType.INT16:
72 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
73 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080074 return np.int32(
75 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
76 )
Eric Kunzee5e26762020-10-13 16:11:07 -070077 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080078 return np.int64(
79 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
80 )
Eric Kunzee5e26762020-10-13 16:11:07 -070081 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +010082 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070083 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -080084 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -070085
Kevin Cheng989cb052021-04-28 16:29:44 -070086 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -070087 placeholders = []
88
Kevin Cheng989cb052021-04-28 16:29:44 -070089 assert len(shape_list) == len(dtype_list)
90
91 for idx, shape in enumerate(shape_list):
92 arr = self.getRandTensor(shape, dtype_list[idx])
93 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -070094
95 return placeholders
96
Kevin Cheng989cb052021-04-28 16:29:44 -070097 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -070098 consts = []
99
Kevin Cheng989cb052021-04-28 16:29:44 -0700100 assert len(shape_list) == len(dtype_list)
101
102 for idx, shape in enumerate(shape_list):
103 arr = self.getRandTensor(shape, dtype_list[idx])
104 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700105
106 return consts
107
108 def makeShape(self, rank):
109 if self.targetted_shape:
110 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800111 return np.int32(
112 self.rng.integers(
113 low=self.args.tensor_shape_range[0],
114 high=self.args.tensor_shape_range[1],
115 size=rank,
116 )
117 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700118
119 def setTargetShape(self, shape):
120 self.targetted_shape = shape
121
122 def randInt(self, low=0, high=256):
123 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
124
125 def getRandNumberDType(self, dtype):
126 if dtype == DType.FLOAT:
127 return self.rng.random()
128 elif dtype == DType.BOOL:
129 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700130 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700131 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700132 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700133 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100134 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700135 elif dtype == DType.INT16:
136 low, high = (-32768, 32768)
137 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800138 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700139 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800140 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 # Special size
142 return np.int64(self.rng.integers(low, high, size=1))[0]
143 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800144 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
146 return np.int32(self.rng.integers(low, high, size=1))[0]
147
148 def shapeStr(self, shape):
149
150 sStr = []
151 # Convert to strings
152 for i in shape:
153 sStr.append(str(i))
154
Kevin Cheng550ccc52021-03-03 11:21:43 -0800155 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700156
157 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -0700158 if isinstance(t, list):
159 assert len(t) >= 2
160 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700161 else:
Kevin Cheng989cb052021-04-28 16:29:44 -0700162 if t == DType.BOOL:
163 return "b"
164 elif t == DType.INT4:
165 return "i4"
166 elif t == DType.INT8:
167 return "i8"
168 elif t == DType.UINT8:
169 return "u8"
170 elif t == DType.INT16:
171 return "i16"
172 elif t == DType.INT32:
173 return "i32"
174 elif t == DType.INT48:
175 return "i48"
176 elif t == DType.FLOAT:
177 return "float"
178 else:
179 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700180
181 def typeWidth(self, t):
Jeremy Johnson5d1a3472022-03-31 09:50:06 +0100182 """Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -0800183 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 return 4
185 elif t == DType.INT8:
186 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -0800187 elif t == DType.UINT8:
188 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -0700189 elif t == DType.INT16:
190 return 16
191 elif t == DType.INT32:
192 return 32
193 elif t == DType.INT48:
194 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +0100195 elif t == DType.FLOAT:
196 return 32
197 elif t == DType.BOOL:
198 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700199 else:
Les Bell729b0352021-11-24 10:28:21 +0000200 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700201
    # Argument generators (see generator.tosa_arg_gen) return a list of
    # tuples (stringDescriptor, [build_fcn_arg_list]), where the string
    # descriptor is used to generate the test name and the build_fcn_arg_list
    # is expanded and passed to the operator build functions defined below.
207
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100208 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
209 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
210
Matthew Haddon848efb42021-09-09 12:30:53 +0100211 # build_placeholder returns an int, ABS/other ops does not
212 if isinstance(op, int):
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100213 self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
214 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000215 elif op["op"] == Op.IDENTITY:
216 self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100217 return result_tens
218
219 # Ensure new output type has correct qinfo
220 if error_name == ErrorIf.WrongOutputType:
221 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
222 qinfo = ts.TosaSerializerQuantInfo()
223 qinfo.UnaryQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000224 TosaQuantGen.getQinfo(self, a.dtype),
225 TosaQuantGen.getQinfo(self, result_tens.dtype),
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100226 )
227
228 # Invalidate Input/Output list for error if checks.
229 input_list = [a.name]
230 output_list = [result_tens.name]
231 pCount, cCount = op["operands"]
232 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000233 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
234 self, error_name, input_list, output_list
235 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100236
Les Bell729b0352021-11-24 10:28:21 +0000237 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100238 self.ser,
239 validator_fcns,
240 error_name,
241 op=op,
242 input_dtype=a.dtype,
243 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000244 qinfo=qinfo,
245 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100246 input_list=input_list,
247 output_list=output_list,
248 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000249 ):
250 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100251
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000252 self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -0700253 return result_tens
254
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100255 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000256 result_tens = OutputShaper.binaryBroadcastOp(
257 self.ser, self.rng, a, b, error_name
258 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100259
260 # Invalidate Input/Output list for error if checks.
261 input_list = [a.name, b.name]
262 output_list = [result_tens.name]
263 pCount, cCount = op["operands"]
264 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000265 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
266 self, error_name, input_list, output_list
267 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100268
Les Bell729b0352021-11-24 10:28:21 +0000269 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100270 self.ser,
271 validator_fcns,
272 error_name,
273 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000274 input1=a,
275 input2=b,
276 input_dtype=a.dtype,
277 output_dtype=result_tens.dtype,
278 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100279 input_list=input_list,
280 output_list=output_list,
281 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000282 ):
283 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100284
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000285 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700286 return result_tens
287
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100288 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000290 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700291 return result_tens
292
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000293 def build_arithmetic_right_shift(
294 self, op, a, b, round, validator_fcns=None, error_name=None
295 ):
296 result_tens = OutputShaper.binaryBroadcastOp(
297 self.ser, self.rng, a, b, error_name
298 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100299
300 # Invalidate Input/Output list for error if checks.
301 input_list = [a.name, b.name]
302 output_list = [result_tens.name]
303 pCount, cCount = op["operands"]
304 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000305 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
306 self, error_name, input_list, output_list
307 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100308
Les Bell729b0352021-11-24 10:28:21 +0000309 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100310 self.ser,
311 validator_fcns,
312 error_name,
313 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000314 input1=a,
315 input2=b,
316 input_dtype=a.dtype,
317 output_dtype=result_tens.dtype,
318 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100319 input_list=input_list,
320 output_list=output_list,
321 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000322 ):
323 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800324
325 attr = ts.TosaSerializerAttribute()
326 attr.ArithmeticRightShiftAttribute(round)
327
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000328 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800329 return result_tens
330
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100331 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000332 result_tens = OutputShaper.binaryBroadcastOp(
333 self.ser, self.rng, a, b, error_name
334 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700335
336 # Special for multiply:
337 # Force the result to INT32 for INT types
338 if a.dtype != DType.FLOAT:
339 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100340 if error_name == ErrorIf.WrongOutputType:
341 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
342 outputDType = self.rng.choice(all_dtypes)
343 result_tens.setDtype(outputDType)
344
345 # Invalidate Input/Output list for error if checks.
346 input_list = [a.name, b.name]
347 output_list = [result_tens.name]
348 pCount, cCount = op["operands"]
349 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000350 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
351 self, error_name, input_list, output_list
352 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100353
Les Bell729b0352021-11-24 10:28:21 +0000354 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100355 self.ser,
356 validator_fcns,
357 error_name,
358 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000359 input1=a,
360 input2=b,
361 input_dtype=a.dtype,
362 output_dtype=result_tens.dtype,
363 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100364 input_list=input_list,
365 output_list=output_list,
366 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000367 ):
368 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700369
Kevin Chengaee1fac2020-11-11 13:54:06 -0800370 attr = ts.TosaSerializerAttribute()
371 attr.MulAttribute(shift)
372
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000373 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700374 return result_tens
375
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100376 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
377 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700378
Kevin Chengfe392ce2021-10-18 21:51:55 +0000379 attr = ts.TosaSerializerAttribute()
380 attr.TableAttribute(table)
381
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100382 # Invalidate Input/Output list for error if checks.
383 input_list = [a.name]
384 output_list = [result_tens.name]
385 pCount, cCount = op["operands"]
386 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000387 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
388 self, error_name, input_list, output_list
389 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100390
Les Bell729b0352021-11-24 10:28:21 +0000391 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100392 self.ser,
393 validator_fcns,
394 error_name,
395 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000396 input_shape=a.shape,
397 input_dtype=a.dtype,
398 output_dtype=result_tens.dtype,
399 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100400 input_list=input_list,
401 output_list=output_list,
402 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000403 ):
404 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100405
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000406 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700407
408 return result_tens
409
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
411 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
412
413 # Invalidate Input/Output list for error if checks.
414 input_list = [cond.name, a.name, b.name]
415 output_list = [result_tens.name]
416 pCount, cCount = op["operands"]
417 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000418 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
419 self, error_name, input_list, output_list
420 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100421
Les Bell729b0352021-11-24 10:28:21 +0000422 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100423 self.ser,
424 validator_fcns,
425 error_name,
426 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000427 input1=cond,
428 input2=a,
429 input3=b,
430 input_shape=a.shape,
431 input_dtype=a.dtype,
432 output_dtype=result_tens.dtype,
433 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100434 input_list=input_list,
435 output_list=output_list,
436 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000437 ):
438 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100439
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000440 self.ser.addOperator(
441 op["op"],
442 input_list,
443 output_list,
444 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700445 return result_tens
446
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100447 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000448 result_tens = OutputShaper.binaryComparisonOp(
449 self.ser, self.rng, a, b, error_name
450 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100451
452 # Invalidate Input/Output list for error if checks.
453 input_list = [a.name, b.name]
454 output_list = [result_tens.name]
455 pCount, cCount = op["operands"]
456 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000457 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
458 self, error_name, input_list, output_list
459 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100460
Les Bell729b0352021-11-24 10:28:21 +0000461 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100462 self.ser,
463 validator_fcns,
464 error_name,
465 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000466 input1=a,
467 input2=b,
468 input_shape=a.shape,
469 input_dtype=a.dtype,
470 output_shape=result_tens.shape,
471 output_dtype=result_tens.dtype,
472 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100473 input_list=input_list,
474 output_list=output_list,
475 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000476 ):
477 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000479 self.ser.addOperator(
480 op["op"],
481 input_list,
482 output_list,
483 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700484 return result_tens
485
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100486 def build_argmax(self, op, a, axis, validator_fcns, error_name):
487 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
488
489 # Invalidate Input/Output list for error if checks.
490 input_list = [a.name]
491 output_list = [result_tens.name]
492 pCount, cCount = op["operands"]
493 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000494 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
495 self, error_name, input_list, output_list
496 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100497
Les Bell729b0352021-11-24 10:28:21 +0000498 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100499 self.ser,
500 validator_fcns,
501 error_name,
502 op=op,
503 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000504 input_shape=a.shape,
505 input_dtype=a.dtype,
506 output_shape=result_tens.shape,
507 output_dtype=result_tens.dtype,
508 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100509 input_list=input_list,
510 output_list=output_list,
511 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000512 ):
513 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700514
515 attr = ts.TosaSerializerAttribute()
516 attr.AxisAttribute(axis)
517
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000518 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700519 return result_tens
520
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000521 def build_pool2d(
522 self,
523 op,
524 input,
525 stride,
526 pad,
527 kernel,
528 validator_fcns=None,
529 error_name=None,
530 qinfo=None,
531 ):
532 result_tens = OutputShaper.pool2dOp(
533 self.ser, self.rng, input, kernel, stride, pad, error_name
534 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100535
536 # Ensure new output type has correct qinfo
537 if error_name == ErrorIf.WrongInputType:
538 if input.dtype not in [DType.INT8, DType.UINT8]:
539 qinfo = ts.TosaSerializerQuantInfo()
540 qinfo.UnaryQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000541 TosaQuantGen.getQinfo(self, input.dtype),
542 TosaQuantGen.getQinfo(self, result_tens.dtype),
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100543 )
544
545 # Invalidate Input/Output list for error if checks.
546 input_list = [input.name]
547 output_list = [result_tens.name]
548 pCount, cCount = op["operands"]
549 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000550 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
551 self, error_name, input_list, output_list
552 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100553
Les Bell729b0352021-11-24 10:28:21 +0000554 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100555 self.ser,
556 validator_fcns,
557 error_name,
558 op=op,
559 input_shape=input.shape,
560 input_dtype=input.dtype,
561 output_shape=result_tens.shape,
562 output_dtype=result_tens.dtype,
563 kernel=kernel,
564 stride=stride,
565 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000566 qinfo=qinfo,
567 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100568 input_list=input_list,
569 output_list=output_list,
570 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000571 ):
572 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700573
574 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -0700575 attr.PoolAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -0700576
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000577 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -0700578 return result_tens
579
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000580 def build_conv2d(
581 self,
582 op,
583 ifm,
584 filter,
585 bias,
586 strides,
587 padding,
588 dilations,
589 validator_fcns=None,
590 error_name=None,
591 qinfo=None,
592 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800593 assert len(padding) == 4
594 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000595 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
596 )
597
598 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000599 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
600 DType.INT8,
601 DType.UINT8,
602 ):
Les Bell0e027d42021-11-09 14:42:14 +0000603 qinfo = ts.TosaSerializerQuantInfo()
604 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000605 TosaQuantGen.getQinfo(self, ifm.dtype),
606 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +0000607 )
608
609 # Invalidate Input/Output list for error_if checks.
610 input_list = [ifm.name, filter.name, bias.name]
611 output_list = [result_tens.name]
612 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000613 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
614 self, error_name, input_list, output_list
615 )
Les Bell0e027d42021-11-09 14:42:14 +0000616
Les Bell729b0352021-11-24 10:28:21 +0000617 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000618 self.ser,
619 validator_fcns,
620 error_name,
621 op=op,
622 input_dtype=ifm.dtype,
623 weight_dtype=filter.dtype,
624 output_dtype=result_tens.dtype,
625 qinfo=qinfo,
626 input_list=input_list,
627 num_operands=num_operands,
628 output_list=output_list,
629 pad=padding,
630 stride=strides,
631 dilation=dilations,
632 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100633 weight_shape=filter.shape,
634 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000635 ):
636 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700637
638 attr = ts.TosaSerializerAttribute()
Kevin Cheng93a16282021-08-31 16:14:03 -0700639 attr.ConvAttribute(padding, strides, dilations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700640
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000641 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -0700642 return result_tens
643
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000644 def build_conv3d(
645 self,
646 op,
647 ifm,
648 filter,
649 bias,
650 strides,
651 padding,
652 dilations,
653 validator_fcns=None,
654 error_name=None,
655 qinfo=None,
656 ):
Kevin Cheng1533b852021-09-01 12:51:58 -0700657 assert len(padding) == 6
658 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000659 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
660 )
661
662 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000663 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
664 DType.INT8,
665 DType.UINT8,
666 ):
Les Bell0e027d42021-11-09 14:42:14 +0000667 qinfo = ts.TosaSerializerQuantInfo()
668 qinfo.ConvQuantInfo(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000669 TosaQuantGen.getQinfo(self, ifm.dtype),
670 TosaQuantGen.getQinfo(self, result_tens.dtype),
Les Bell0e027d42021-11-09 14:42:14 +0000671 )
672
673 # Invalidate Input/Output list for error_if checks.
674 input_list = [ifm.name, filter.name, bias.name]
675 output_list = [result_tens.name]
676 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000677 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
678 self, error_name, input_list, output_list
679 )
Les Bell0e027d42021-11-09 14:42:14 +0000680
Les Bell729b0352021-11-24 10:28:21 +0000681 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000682 self.ser,
683 validator_fcns,
684 error_name,
685 op=op,
686 input_dtype=ifm.dtype,
687 weight_dtype=filter.dtype,
688 output_dtype=result_tens.dtype,
689 qinfo=qinfo,
690 input_list=input_list,
691 num_operands=num_operands,
692 output_list=output_list,
693 pad=padding,
694 stride=strides,
695 dilation=dilations,
696 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100697 weight_shape=filter.shape,
698 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000699 ):
700 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700701
702 attr = ts.TosaSerializerAttribute()
703 attr.ConvAttribute(padding, strides, dilations)
704
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000705 self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
Kevin Cheng1533b852021-09-01 12:51:58 -0700706 return result_tens
707
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    stride,
    outpad,
    dilation,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a transpose-conv2d operator and return its output tensor.

    Args:
        op: Operator table entry; ``op["op"]`` is the opcode that is
            serialized and ``op["operands"]`` holds counts that are summed
            into ``num_operands`` for the ERROR_IF checks.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        stride: Stride values passed to ``TransposeConvAttribute``.
        outpad: Output padding; must contain exactly four entries.
        dilation: Dilation values passed to ``TransposeConvAttribute``.
        output_shape: Requested output shape, used both to create the
            result tensor and in the serialized attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: Quantization info passed through to the serializer.

    Returns:
        The result tensor, or ``None`` when ERROR_IF validation rejects the
        configuration (in which case no operator is added).
    """
    assert len(outpad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, error_name
    )

    # Ensure new output type has correct qinfo: a WrongInputType case with a
    # non INT8/UINT8 input gets quant info rebuilt from the dtypes in use.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=outpad,
        stride=stride,
        # Forward dilation as the conv2d/depthwise builders do, so validators
        # see the same value that is serialized in the attribute below.
        dilation=dilation,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
    return result_tens
771
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a depthwise-conv2d operator and return its output tensor.

    The result tensor comes from ``OutputShaper.depthwiseConv2dOp``. The
    operator is only added to the serializer when the ERROR_IF validators
    accept the configuration.

    Args:
        op: Operator table entry (``"op"`` opcode, ``"operands"`` counts).
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        strides: Strides forwarded to ``ConvAttribute``.
        padding: Padding forwarded to ``ConvAttribute``.
        dilations: Dilations forwarded to ``ConvAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: Quantization info; rebuilt for non-8-bit WrongInputType cases.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # A WrongInputType case whose input is not INT8/UINT8 needs quant info
    # that matches the dtypes actually in use.
    rebuild_qinfo = error_name == ErrorIf.WrongInputType and (
        ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if rebuild_qinfo:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    operand_total = sum(op["operands"])
    # Let the ERROR_IF helper invalidate the name lists when the test asks for it.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_total,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations)

    self.ser.addOperator(op["op"], in_names, out_names, conv_attr, qinfo)
    return result_tens
834
def build_fully_connected(
    self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a fully-connected operator and return its output tensor.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: Quantization info passed through to the serializer.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    result_tens = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not accepted:
        return None

    # No attribute object is passed; only qinfo accompanies the operator.
    self.ser.addOperator(op["op"], in_names, out_names, None, qinfo)
    return result_tens
871
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Serialize a matmul of tensors ``a`` and ``b`` and return the result.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: First input tensor.
        b: Second input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: Quantization info passed through to the serializer.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, None, qinfo)
    return out_tensor
905
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator along ``axis`` and return its output.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        axis: Reduction axis, stored in an ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to run. Defaults to
            ``None`` like the other ``build_*`` methods; callers that pass it
            explicitly are unaffected.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
940
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a clamp operator with randomly chosen bounds.

    Two random values of ``a.dtype`` are drawn. Normally the larger one
    becomes ``max_val``. For the ``MaxSmallerMin`` ERROR_IF case the pair is
    redrawn until the values differ, then deliberately assigned so that
    ``max_val < min_val``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal bounds cannot provoke the error, so keep drawing until distinct.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    ):
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    # FLOAT inputs place the bounds in the last two ClampAttribute slots
    # (presumably the float bounds); other dtypes use the first two.
    if a.dtype == DType.FLOAT:
        clamp_args = (0, 0, min_val, max_val)
    else:
        clamp_args = (min_val, max_val, 0, 0)
    clamp_attr.ClampAttribute(*clamp_args)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return result_tens
991
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a leaky-relu operator with a random FLOAT attribute value.

    NOTE(review): unlike the sigmoid/tanh builders, this performs no ERROR_IF
    validation or input/output list invalidation, and ``validator_fcns`` is
    unused.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FLOAT)
    relu_attr.LeakyReluAttribute(alpha)

    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names, relu_attr)
    return out_tensor
1000
# Incomplete: needs an additional type/input (only ``a`` is wired up below).
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize the prelu operator on the single input tensor ``a``.

    NOTE(review): ``validator_fcns`` is unused and no ERROR_IF validation is
    performed here.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1007
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a sigmoid operator on tensor ``a``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1038
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a tanh operator on tensor ``a``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    param_count, const_count = op["operands"]
    operand_total = param_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1069
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a concat operator over a variable number of input tensors.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        *a: Input tensors followed by the concat axis as the final element.
            The axis travels with the tensors because the input count varies.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(a[-1], int)

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001118
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a pad operator on tensor ``a``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        padding: Padding amounts; must support ``.flatten()`` (presumably a
            NumPy array -- confirm with the argument generator).
        pad_const_int: Integer pad constant stored in the attribute.
        pad_const_float: Float pad constant stored in the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: Quantization info passed through to the serializer.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # The attribute stores the padding flattened into one dimension.
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr, qinfo)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001164
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a reshape of tensor ``a`` to ``newShape``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        newShape: Target shape, stored in a ``ReshapeAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not valid:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return out_tensor
1200
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reverse of tensor ``a`` along ``axis``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        axis: Axis stored in an ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1235
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a transpose of tensor ``a`` using permutation ``perms``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        perms: Permutation stored in a ``TransposeAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)
    return out_tensor
1270
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a slice of tensor ``a`` described by ``start`` and ``size``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        start: Slice start, stored in a ``SliceAttribute``.
        size: Slice size, stored in a ``SliceAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1308
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a tile of tensor ``a`` by ``multiples``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        a: Input tensor.
        multiples: Repeat counts stored in a ``TileAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not valid:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out_tensor
1342
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a gather from ``values`` using freshly generated indices.

    A constant INT32 indices tensor of shape ``(values.shape[0], W)`` is added
    to the serializer. ``W`` is drawn with ``self.randInt`` over
    ``self.args.tensor_shape_range``. Each index is drawn from
    ``[0, values.shape[1])``, so it never exceeds the values tensor.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        values: Values tensor to gather from.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    k_dim = values.shape[1]  # K
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tensor
1389
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a scatter of ``input`` into ``values_in`` at random indices.

    A constant INT32 indices tensor of shape ``(values_in.shape[0],
    input.shape[1])`` is added to the serializer. Each index is drawn from
    ``[0, values_in.shape[1])``, so it never exceeds ``values_in``.

    Args:
        op: Operator table entry; ``op["operands"]`` unpacks into two counts.
        values_in: Tensor being scattered into.
        input: Tensor supplying the scattered values.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case being generated, or ``None``.

    Returns:
        The result tensor, or ``None`` if ERROR_IF validation fails.
    """
    k_dim = values_in.shape[1]  # K
    w_dim = input.shape[1]  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tensor.name],
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tensor
1434
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on ``input``.

    The result tensor is created by OutputShaper.resizeOp from the resize
    parameters. The same parameters are passed to the ERROR_IF validators
    and then stored in the operator's ResizeAttribute.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    # ERROR_IF tests may deliberately corrupt the operand name lists
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens

def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an operator with two inputs (``val``, ``val2``) and two outputs.

    Each output tensor is shaped by OutputShaper.unaryOp from its input.
    ``validator_fcns`` is accepted for the common builder signature but is
    not used here.

    Returns only the first result tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator's test-list entry (a dict, as in every other
    # builder); the serializer needs the opcode stored under "op".
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens

def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the supplied tensor ``val`` as a graph output and return it.

    No operator is created. ``op``, ``validator_fcns`` and ``error_name`` are
    accepted only to match the common builder signature and are unused.
    """
    const_tens = val
    self.ser.addOutputTensor(const_tens)
    return const_tens

# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a type-conversion operator converting ``val`` to ``out_dtype``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # ERROR_IF tests may deliberately corrupt the operand name lists
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens

def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Chooses random input/output zero points. These are used for INT8/UINT8,
    or forced non-zero for the zero-point ERROR_IF tests. It then picks
    random scales: one per entry of the last dimension when ``per_channel``
    is set, otherwise a single scale. The scales are converted to
    multiplier/shift pairs.

    For error-free scale32 tests, the placeholder data already saved for
    ``val`` is clamped (and re-saved if changed) so that it satisfies the
    apply_scale_32 precondition.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the configuration.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Zero points. Whenever a zero point may be non-zero, the type width
    # used for the scale computation is widened by one bit.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width = in_type_width + 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width = in_type_width + 1
    elif error_name == ErrorIf.InputZeroPointNotZero:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width = in_type_width + 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width = out_type_width + 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width = out_type_width + 1
    elif error_name == ErrorIf.OutputZeroPointNotZero:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width = out_type_width + 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds with a Python int. Shifting a NumPy int32
        # scalar is evaluated in 32 bits under NumPy 2 scalar promotion (or
        # where the default int is 32-bit), and shifts > 31 then yield 0
        # instead of the 64-bit bound.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 2)
        max_shift_value_arr[i] = (1 << (shift - 2)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within the apply_scale_32 specification:
        # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens

def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks each output a random INT32 const.

    Only the shape of ``then_tens`` is used (``else_tens`` is ignored). The
    value ``cond`` becomes a constant BOOL condition tensor. For the
    output-list mismatch ERROR_IF tests, the affected block instead outputs
    a constant with a deliberately different shape.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated graph.
    """
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"

    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    out_shape = then_tens.shape

    output_mismatch_errors = (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    )
    if error_name in output_mismatch_errors:
        # Dimensions above 3 move by +/-2 or +/-3; smaller ones only grow
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            deltas = [-3, -2, 2, 3] if dim > 3 else [1, 2, 4]
            incorrect_shape[idx] = dim + self.rng.choice(deltas)
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor shares the shape of the (correct) block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], cond_attr)

    # Each block consists solely of its own constant output
    for block_name, block_error, block_arr in (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    ):
        self.ser.startBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    return result_tens if valid else None

def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose then/else blocks apply a binary op to ``a``, ``b``.

    FLOAT/INT32 inputs use ADD (then) and SUB (else). INT8/INT16 inputs use
    LOGICAL_RIGHT_SHIFT (then) and LOGICAL_LEFT_SHIFT (else). For the
    input/output list mismatch ERROR_IF tests, the affected block is given an
    input or output tensor whose shape differs from ``a``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated graph.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else operator pair.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Only shrink dimensions that stay positive, so the test checks a
            # genuine shape mismatch rather than an invalid (<= 0) dimension.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Explicit raise so the check survives `python -O`
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # The loop variable must not be called ``op``: the ERROR_IF validation
    # below needs the original op entry, not the last block's opcode.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens

def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP carrying (counter, a, accumulator).

    The counter starts at ``iter_val`` and the accumulator at zeros. The COND
    block computes ``counter > 0``. The BODY block computes ``acc + a`` and
    ``counter - 1``. The ERROR_IF variants deliberately give some loop/block
    tensors mismatched shapes, or give the condition a non-BOOL type.

    Returns the loop's accumulator output tensor, or None when the ERROR_IF
    validators reject the generated graph.
    """

    def _reshaped_copy(tens):
        # Deep copy of ``tens`` with each dimension shifted by -3/-2/+2/+3
        bad = deepcopy(tens)
        for dim_idx in range(len(bad.shape)):
            bad.shape[dim_idx] += self.rng.choice([-3, -2, 2, 3])
        return bad

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    loop_attr = ts.TosaSerializerAttribute()
    loop_attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = _reshaped_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        loop_attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = _reshaped_copy(iter_tens)
        if len(incorrect_iter.shape) == 0:
            # The counter has shape [], so give it an extra dimension instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
        incorrect_acc = _reshaped_copy(acc)

    # COND block (inputs: counter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_dtype = self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
    else:
        cond_dtype = DType.BOOL
    cond_tens = self.ser.addOutput([], cond_dtype)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs and outputs: counter, a, acc). Local intermediate
    # tensors are declared here for the block outputs.
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_src, acc_src = incorrect_iter, incorrect_acc
    else:
        iter_src, acc_src = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_src.shape, iter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    return acc_out if valid else None

def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests.

    Args:
        op: operator entry; its "rank" (min, max) pair and "types" are read.
        shapeFilter: requested shapes, or a falsy value for "no shape filter".
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, called as validator(check=False, op=op)
            for negative tests; its "param_reqs" override the filters.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or an unknown testType.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Default behaviour is bounded by the default range or by the
        # operator, whichever is smaller - and never outside the operator's
        # allowed ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            lo = max(rmin, default_test_rank_range.start)
            hi = min(rmax + 1, default_test_rank_range.stop)
            if lo < hi:
                cleanRankFilter = range(lo, hi)
            else:
                # Operator ranks lie entirely outside the default range
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Operator dtypes filtered by the requested dtypes; list entries are
        # matched on their first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF validator
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None

Kevin Cheng550ccc52021-03-03 11:21:43 -08002037 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002038 self,
2039 opName,
2040 shapeFilter=[None],
2041 rankFilter=None,
2042 dtypeFilter=None,
2043 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08002044 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002045
2046 try:
2047 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002048 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002049 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002050
2051 # Initialize a new random number generator
2052 self.rng = np.random.default_rng(self.random_seed)
2053
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002054 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002055
Eric Kunzee5e26762020-10-13 16:11:07 -07002056 # Test list consists of a tuple of:
2057 # (opName, testNameStr, dtype, shapeList, argumentsList)
2058 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002059 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002060 error_if_validators = op["error_if_validators"]
2061 else:
2062 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002063
Matthew Haddon1c00b712021-10-01 15:51:03 +01002064 for validator in error_if_validators:
2065 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002066 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002067 else:
2068 error_name = None
2069
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002070 filterDict = self.create_filter_lists(
2071 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
2072 )
2073 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01002074 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002075 cleanRankFilter = filterDict["rankFilter"]
2076 cleanDtypeFilter = filterDict["dtypeFilter"]
2077 cleanShapeFilter = filterDict["shapeFilter"]
2078 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002079
2080 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002081 for t in cleanDtypeFilter:
2082 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002083 # Filter out by rank
2084 if shape is not None and len(shape) != r:
2085 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002086 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002087 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002088
Matthew Haddon74567092021-07-16 15:38:20 +01002089 shapeStr = self.shapeStr(shapeList[0])
2090 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002091
Matthew Haddon74567092021-07-16 15:38:20 +01002092 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2093 argList = []
2094 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002095 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002096 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002097 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002098
Matthew Haddon74567092021-07-16 15:38:20 +01002099 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002100 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002101 if argStr:
2102 testStr = "{}_{}_{}_{}".format(
2103 opName, shapeStr, typeStr, argStr
2104 )
2105 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002106 testStr = "{}_{}_{}".format(
2107 opName, shapeStr, typeStr
2108 )
2109 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01002110 if argStr:
2111 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2112 opName, error_name, shapeStr, typeStr, argStr
2113 )
2114 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002115 testStr = "{}_ERRORIF_{}_{}_{}".format(
2116 opName, error_name, shapeStr, typeStr
2117 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002118
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002119 testList.append(
2120 (opName, testStr, t, error_name, shapeList, args)
2121 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002122
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002123 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002124 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2125 if "invalid_test_validators" in op:
2126 invalid_test_validators = op["invalid_test_validators"]
2127 clean_testList = []
2128 for test in testList:
2129 for validator_fcn in invalid_test_validators:
2130 remove_test = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002131 if validator_fcn(
2132 opName=test[0],
2133 input_dtype=test[2],
2134 shapeList=test[4],
2135 args=test[5],
2136 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002137 remove_test = True
2138 if not remove_test:
2139 clean_testList.append(test)
2140 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002141
2142 return testList
2143
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test from its generated parameters and serialize it.

    Looks up ``opName`` in ``self.TOSA_OP_LIST``, creates a serializer for
    ``testStr``, generates the operand tensors with the op's tensor-values
    function and then calls the op's build function. If the build function
    returns a truthy result name the test is written with
    ``self.serialize("test")``; otherwise a message is printed and nothing is
    serialized.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name, used as the serializer test path.
        dtype_or_dtypeList: either a list with one dtype per operand (used
            as-is) or a single dtype replicated for every operand (for CONCAT,
            once per entry of ``shapeList``).
        error_name: ERROR_IF condition under test, or None.
        shapeList: one shape per operand.
        testArgs: extra positional arguments forwarded to the build function.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        AssertionError: for non-CONCAT ops, if the shape or dtype list length
            differs from the op's operand count.
        TypeError: re-raised from the build function after printing its inputs.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        # The KeyError adds nothing beyond this message
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Create a serializer
    self.createSerializer(opName, testStr)

    # Unpacking also enforces the 4-tuple layout of "build_fcn"
    build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT has a variable number of inputs: one dtype per input shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
    else:
        qinfo = None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)

    # Ops without ERROR_IF validators receive qinfo (when present) as a
    # trailing positional argument; ops with validators receive the validator
    # functions, error name and qinfo by keyword.
    build_args = [self, op, *tens, *testArgs]
    build_kwargs = {}
    if error_if_validators is None:
        if qinfo is not None:
            build_args.append(qinfo)
    else:
        build_kwargs["validator_fcns"] = error_if_validators
        build_kwargs["error_name"] = error_name
        if qinfo is not None:
            build_kwargs["qinfo"] = qinfo

    try:
        resultName = build_fcn(*build_args, **build_kwargs)
    except TypeError:
        # Typically a mismatch between the build function signature and the
        # generated tensors/arguments; show what was passed before re-raising
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002234
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    Every ``*_TEMPLATE`` entry is shallow-copied once per kernel size. For
    example, ``conv2d_TEMPLATE`` becomes ``conv2d_1x1``, ``conv2d_2x2``, and
    so on. Each copy stores its kernel in ``"filter"`` and has
    ``"template": False``. Afterwards every entry still marked
    ``"template": True`` is removed.

    ``TOSA_OP_LIST`` is a class attribute modified in place, so the expansion
    is shared by all instances. When no template entries remain (an earlier
    instance already expanded them), the expansion step is skipped instead of
    failing with a KeyError on the missing templates.
    """
    op_list = self.TOSA_OP_LIST

    # Dynamically create op lists for convolutions with a list of kernel sizes
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    TEMPLATE_NAMES = (
        "conv2d_TEMPLATE",
        "depthwise_conv2d_TEMPLATE",
        "transpose_conv2d_TEMPLATE",
        "conv3d_TEMPLATE",
    )

    def _expand(template_name, prefix, kernel):
        # Register "<prefix>_<k0>x<k1>[x<k2>]" as a concrete copy of the template
        entry = op_list[template_name].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[prefix + "_" + "x".join(str(dim) for dim in kernel)] = entry

    if any(name in op_list for name in TEMPLATE_NAMES):
        # Insertion order is kept (2D ops interleaved per kernel, then 3D)
        # because it determines the order in which tests are generated
        for k in KERNELS_2D:
            _expand("conv2d_TEMPLATE", "conv2d", k)
            _expand("depthwise_conv2d_TEMPLATE", "depthwise_conv2d", k)
            _expand("transpose_conv2d_TEMPLATE", "transpose_conv2d", k)
        for k in KERNELS_3D:
            _expand("conv3d_TEMPLATE", "conv3d", k)

    # Delete any templates after having created any dynamic ops.
    # Collect the keys first: deleting while iterating a dict is an error.
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
2281
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in missing optional fields.

    Required per entry: an ``operands`` pair, a 4-element ``build_fcn``
    tuple, a ``types`` field and an ``op`` field. The first entry that
    violates one of these raises an ``Exception`` naming the op. Entries
    without ``rank`` are given ``self.DEFAULT_RANK_RANGE``.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Tuple-unpacking doubles as a length/shape check for these fields
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _shape_gen, _value_gen, _arg_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(f"Op {name} is missing a valid type list in TOSA_OP_LIST")

        if "op" not in entry:
            raise Exception(f"Op {name} is missing the Op field in TOSA_OP_LIST")

        # Optional field: default rank range when the op does not restrict it
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002323
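# Tensor operator list (TOSA_OP_LIST below), keyed by test op name.
# Each entry has the fields:
# 'op': the TOSA Op enum value
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to the inclusive tuple (min, max);
#     if not specified, defaults to DEFAULT_RANK_RANGE, i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, TosaTensorGen shape function,
#     TosaTensorValuesGen tensor-values function, TosaArgGen argument function or None)
# 'types': list of datatypes to be tested
# Optional fields: 'qgen' (TosaQuantGen function), 'error_if_validators',
#     'invalid_test_validators', and for templated ops 'template' / 'filter'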
# Reusable dtype lists for the 'types' field of TOSA_OP_LIST entries
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Types for the convolution / fully-connected ops (three operands each).
# Entries may be a per-operand dtype list or a single dtype. serializeTest uses
# a list as-is (one dtype per operand) and replicates a single dtype
# (DType.FLOAT here) for every operand. The list order is presumably
# (input, weight, accumulator) -- NOTE(review): confirm against the conv builders.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# Inclusive (min, max) rank used by initOpListDefaults for ops without 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002351
2352 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002353 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002354 "argmax": {
2355 "op": Op.ARGMAX,
2356 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002357 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002358 "build_fcn": (
2359 build_argmax,
2360 TosaTensorGen.tgBasic,
2361 TosaTensorValuesGen.tvgDefault,
2362 TosaArgGen.agAxis,
2363 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002364 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002365 "error_if_validators": (
2366 TosaErrorValidator.evAxisSmallerZero,
2367 TosaErrorValidator.evAxisLargerRank,
2368 TosaErrorValidator.evArgmaxOutputRankMismatch,
2369 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2370 TosaErrorValidator.evWrongRank,
2371 TosaErrorValidator.evWrongInputType,
2372 TosaErrorValidator.evWrongOutputType,
2373 TosaErrorValidator.evWrongInputList,
2374 TosaErrorValidator.evWrongOutputList,
2375 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002377 "avg_pool2d": {
2378 "op": Op.AVG_POOL2D,
2379 "operands": (1, 0),
2380 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002381 "build_fcn": (
2382 build_pool2d,
2383 TosaTensorGen.tgNHWC,
2384 TosaTensorValuesGen.tvgDefault,
2385 TosaArgGen.agPooling,
2386 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002387 "qgen": TosaQuantGen.qgUnary,
2388 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002389 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002390 "error_if_validators": (
2391 TosaErrorValidator.evKernelSmallerOne,
2392 TosaErrorValidator.evStrideSmallerOne,
2393 TosaErrorValidator.evPadSmallerZero,
2394 TosaErrorValidator.evWrongRank,
2395 TosaErrorValidator.evWrongInputType,
2396 TosaErrorValidator.evWrongOutputType,
2397 TosaErrorValidator.evWrongInputList,
2398 TosaErrorValidator.evWrongOutputList,
2399 TosaErrorValidator.evInputZeroPointNotZero,
2400 TosaErrorValidator.evOutputZeroPointNotZero,
2401 TosaErrorValidator.evPadLargerEqualKernel,
2402 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002403 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002405 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002406 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002407 "conv2d_TEMPLATE": {
2408 "op": Op.CONV2D,
2409 "operands": (1, 2),
2410 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002411 "build_fcn": (
2412 build_conv2d,
2413 TosaTensorGen.tgConv2D,
2414 TosaTensorValuesGen.tvgDefault,
2415 TosaArgGen.agConv,
2416 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002417 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002418 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002419 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2420 "error_if_validators": (
2421 TosaErrorValidator.evWrongInputType,
2422 TosaErrorValidator.evWrongOutputType,
2423 TosaErrorValidator.evWrongInputList,
2424 TosaErrorValidator.evWrongOutputList,
2425 TosaErrorValidator.evInputZeroPointNotZero,
2426 TosaErrorValidator.evWeightZeroPointNotZero,
2427 TosaErrorValidator.evPadSmallerZero,
2428 TosaErrorValidator.evStrideSmallerOne,
2429 TosaErrorValidator.evDilationSmallerOne,
2430 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002431 TosaErrorValidator.evConvOutputShapeMismatch,
2432 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002433 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002434 "template": True,
2435 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002436 # Templated operator. Filled in by createDynamicOpLists
2437 "conv3d_TEMPLATE": {
2438 "op": Op.CONV3D,
2439 "operands": (1, 2),
2440 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002441 "build_fcn": (
2442 build_conv3d,
2443 TosaTensorGen.tgConv3D,
2444 TosaTensorValuesGen.tvgDefault,
2445 TosaArgGen.agConv,
2446 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002447 "qgen": TosaQuantGen.qgConv,
2448 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002449 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2450 "error_if_validators": (
2451 TosaErrorValidator.evWrongInputType,
2452 TosaErrorValidator.evWrongOutputType,
2453 TosaErrorValidator.evWrongInputList,
2454 TosaErrorValidator.evWrongOutputList,
2455 TosaErrorValidator.evInputZeroPointNotZero,
2456 TosaErrorValidator.evWeightZeroPointNotZero,
2457 TosaErrorValidator.evPadSmallerZero,
2458 TosaErrorValidator.evStrideSmallerOne,
2459 TosaErrorValidator.evDilationSmallerOne,
2460 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002461 TosaErrorValidator.evConvOutputShapeMismatch,
2462 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002463 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002464 "template": True,
2465 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002466 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002467 "depthwise_conv2d_TEMPLATE": {
2468 "op": Op.DEPTHWISE_CONV2D,
2469 "operands": (1, 2),
2470 "filter": [1, 1],
2471 "rank": (4, 4),
2472 "build_fcn": (
2473 build_depthwise_conv2d,
2474 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002475 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002476 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002477 ),
2478 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002479 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002480 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2481 "error_if_validators": (
2482 TosaErrorValidator.evWrongInputType,
2483 TosaErrorValidator.evWrongOutputType,
2484 TosaErrorValidator.evWrongInputList,
2485 TosaErrorValidator.evWrongOutputList,
2486 TosaErrorValidator.evInputZeroPointNotZero,
2487 TosaErrorValidator.evWeightZeroPointNotZero,
2488 TosaErrorValidator.evPadSmallerZero,
2489 TosaErrorValidator.evStrideSmallerOne,
2490 TosaErrorValidator.evDilationSmallerOne,
2491 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002492 TosaErrorValidator.evConvOutputShapeMismatch,
2493 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002494 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002495 "template": True,
2496 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002497 "fully_connected": {
2498 "op": Op.FULLY_CONNECTED,
2499 "operands": (1, 2),
2500 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002501 "build_fcn": (
2502 build_fully_connected,
2503 TosaTensorGen.tgFullyConnected,
2504 TosaTensorValuesGen.tvgDefault,
2505 None,
2506 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002507 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002508 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002509 "error_if_validators": (
2510 TosaErrorValidator.evInputZeroPointNotZero,
2511 TosaErrorValidator.evWeightZeroPointNotZero,
2512 TosaErrorValidator.evWrongRank,
2513 TosaErrorValidator.evWrongInputType,
2514 TosaErrorValidator.evWrongOutputType,
2515 TosaErrorValidator.evWrongInputList,
2516 TosaErrorValidator.evWrongOutputList,
2517 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002519 "matmul": {
2520 "op": Op.MATMUL,
2521 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002522 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002523 "build_fcn": (
2524 build_matmul,
2525 TosaTensorGen.tgMatmul,
2526 TosaTensorValuesGen.tvgDefault,
2527 None,
2528 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002529 "qgen": TosaQuantGen.qgMatmul,
2530 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002531 "error_if_validators": (
2532 TosaErrorValidator.evInputZeroPointNotZero,
2533 TosaErrorValidator.evWrongRank,
2534 TosaErrorValidator.evWrongInputType,
2535 TosaErrorValidator.evWrongOutputType,
2536 TosaErrorValidator.evWrongInputList,
2537 TosaErrorValidator.evWrongOutputList,
2538 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002539 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002540 "max_pool2d": {
2541 "op": Op.MAX_POOL2D,
2542 "operands": (1, 0),
2543 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002544 "build_fcn": (
2545 build_pool2d,
2546 TosaTensorGen.tgNHWC,
2547 TosaTensorValuesGen.tvgDefault,
2548 TosaArgGen.agPooling,
2549 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002550 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002551 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002552 "error_if_validators": (
2553 TosaErrorValidator.evKernelSmallerOne,
2554 TosaErrorValidator.evStrideSmallerOne,
2555 TosaErrorValidator.evPadSmallerZero,
2556 TosaErrorValidator.evWrongRank,
2557 TosaErrorValidator.evWrongInputType,
2558 TosaErrorValidator.evWrongOutputType,
2559 TosaErrorValidator.evWrongInputList,
2560 TosaErrorValidator.evWrongOutputList,
2561 TosaErrorValidator.evPadLargerEqualKernel,
2562 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002563 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002564 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002565 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002566 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002567 "transpose_conv2d_TEMPLATE": {
2568 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002569 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002570 "rank": (4, 4),
2571 "build_fcn": (
2572 build_transpose_conv2d,
2573 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002574 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002575 TosaArgGen.agTransposeConv2D,
2576 ),
2577 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002578 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002579 "invalid_test_validators": (
2580 TosaInvalidValidator.ivHeightWidthInvalid,
2581 TosaInvalidValidator.ivNonPositiveOutputShape,
2582 ),
2583 "error_if_validators": (
2584 TosaErrorValidator.evWrongInputType,
2585 TosaErrorValidator.evWrongOutputType,
2586 TosaErrorValidator.evWrongInputList,
2587 TosaErrorValidator.evWrongOutputList,
2588 TosaErrorValidator.evInputZeroPointNotZero,
2589 TosaErrorValidator.evWeightZeroPointNotZero,
2590 TosaErrorValidator.evPadSmallerZero,
2591 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002592 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002593 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002594 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002595 "template": True,
2596 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002597 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002598 "clamp": {
2599 "op": Op.CLAMP,
2600 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002601 "build_fcn": (
2602 build_clamp,
2603 TosaTensorGen.tgBasic,
2604 TosaTensorValuesGen.tvgDefault,
2605 None,
2606 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002607 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002608 "error_if_validators": (
2609 TosaErrorValidator.evMaxSmallerMin,
2610 TosaErrorValidator.evWrongInputType,
2611 TosaErrorValidator.evWrongOutputType,
2612 TosaErrorValidator.evWrongInputList,
2613 TosaErrorValidator.evWrongOutputList,
2614 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002615 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002616 "sigmoid": {
2617 "op": Op.SIGMOID,
2618 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002619 "build_fcn": (
2620 build_sigmoid,
2621 TosaTensorGen.tgBasic,
2622 TosaTensorValuesGen.tvgDefault,
2623 None,
2624 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002626 "error_if_validators": (
2627 TosaErrorValidator.evWrongInputType,
2628 TosaErrorValidator.evWrongOutputType,
2629 TosaErrorValidator.evWrongInputList,
2630 TosaErrorValidator.evWrongOutputList,
2631 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002632 },
2633 "tanh": {
2634 "op": Op.TANH,
2635 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002636 "build_fcn": (
2637 build_tanh,
2638 TosaTensorGen.tgBasic,
2639 TosaTensorValuesGen.tvgDefault,
2640 None,
2641 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002642 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002643 "error_if_validators": (
2644 TosaErrorValidator.evWrongInputType,
2645 TosaErrorValidator.evWrongOutputType,
2646 TosaErrorValidator.evWrongInputList,
2647 TosaErrorValidator.evWrongOutputList,
2648 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002649 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002650 # Elementwise Binary Operators
2651 "add": {
2652 "op": Op.ADD,
2653 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002654 "build_fcn": (
2655 build_binary_broadcast,
2656 TosaTensorGen.tgBroadcastFuzz,
2657 TosaTensorValuesGen.tvgAddSub,
2658 None,
2659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002660 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002661 "error_if_validators": (
2662 TosaErrorValidator.evRankMismatch,
2663 TosaErrorValidator.evWrongInputType,
2664 TosaErrorValidator.evWrongOutputType,
2665 TosaErrorValidator.evWrongInputList,
2666 TosaErrorValidator.evWrongOutputList,
2667 TosaErrorValidator.evDimensionMismatch,
2668 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002670 "arithmetic_right_shift": {
2671 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2672 "operands": (2, 0),
2673 "build_fcn": (
2674 build_arithmetic_right_shift,
2675 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002676 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002677 TosaArgGen.agArithmeticRightShift,
2678 ),
2679 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002680 "error_if_validators": (
2681 TosaErrorValidator.evRankMismatch,
2682 TosaErrorValidator.evWrongInputType,
2683 TosaErrorValidator.evWrongOutputType,
2684 TosaErrorValidator.evWrongInputList,
2685 TosaErrorValidator.evWrongOutputList,
2686 TosaErrorValidator.evDimensionMismatch,
2687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002689 "bitwise_and": {
2690 "op": Op.BITWISE_AND,
2691 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002692 "build_fcn": (
2693 build_binary_broadcast,
2694 TosaTensorGen.tgBroadcastFuzz,
2695 TosaTensorValuesGen.tvgDefault,
2696 None,
2697 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002698 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002699 "error_if_validators": (
2700 TosaErrorValidator.evRankMismatch,
2701 TosaErrorValidator.evWrongInputType,
2702 TosaErrorValidator.evWrongOutputType,
2703 TosaErrorValidator.evWrongInputList,
2704 TosaErrorValidator.evWrongOutputList,
2705 TosaErrorValidator.evDimensionMismatch,
2706 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002707 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002708 "bitwise_or": {
2709 "op": Op.BITWISE_OR,
2710 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002711 "build_fcn": (
2712 build_binary_broadcast,
2713 TosaTensorGen.tgBroadcastFuzz,
2714 TosaTensorValuesGen.tvgDefault,
2715 None,
2716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002717 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002718 "error_if_validators": (
2719 TosaErrorValidator.evRankMismatch,
2720 TosaErrorValidator.evWrongInputType,
2721 TosaErrorValidator.evWrongOutputType,
2722 TosaErrorValidator.evWrongInputList,
2723 TosaErrorValidator.evWrongOutputList,
2724 TosaErrorValidator.evDimensionMismatch,
2725 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002726 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002727 "bitwise_xor": {
2728 "op": Op.BITWISE_XOR,
2729 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002730 "build_fcn": (
2731 build_binary_broadcast,
2732 TosaTensorGen.tgBroadcastFuzz,
2733 TosaTensorValuesGen.tvgDefault,
2734 None,
2735 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002736 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002737 "error_if_validators": (
2738 TosaErrorValidator.evRankMismatch,
2739 TosaErrorValidator.evWrongInputType,
2740 TosaErrorValidator.evWrongOutputType,
2741 TosaErrorValidator.evWrongInputList,
2742 TosaErrorValidator.evWrongOutputList,
2743 TosaErrorValidator.evDimensionMismatch,
2744 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002745 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002746 "intdiv": {
2747 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002748 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002749 "build_fcn": (
2750 build_binary_broadcast,
2751 TosaTensorGen.tgBroadcastFuzz,
2752 TosaTensorValuesGen.tvgIntDiv,
2753 None,
2754 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002755 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002756 "error_if_validators": (
2757 TosaErrorValidator.evRankMismatch,
2758 TosaErrorValidator.evWrongInputType,
2759 TosaErrorValidator.evWrongOutputType,
2760 TosaErrorValidator.evWrongInputList,
2761 TosaErrorValidator.evWrongOutputList,
2762 TosaErrorValidator.evDimensionMismatch,
2763 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002764 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002765 "logical_and": {
2766 "op": Op.LOGICAL_AND,
2767 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002768 "build_fcn": (
2769 build_binary_broadcast,
2770 TosaTensorGen.tgBroadcastFuzz,
2771 TosaTensorValuesGen.tvgDefault,
2772 None,
2773 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002774 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002775 "error_if_validators": (
2776 TosaErrorValidator.evRankMismatch,
2777 TosaErrorValidator.evWrongInputType,
2778 TosaErrorValidator.evWrongOutputType,
2779 TosaErrorValidator.evWrongInputList,
2780 TosaErrorValidator.evWrongOutputList,
2781 TosaErrorValidator.evDimensionMismatch,
2782 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002783 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002784 "logical_left_shift": {
2785 "op": Op.LOGICAL_LEFT_SHIFT,
2786 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002787 "build_fcn": (
2788 build_binary_broadcast,
2789 TosaTensorGen.tgBroadcastFuzz,
2790 TosaTensorValuesGen.tvgLogicalShift,
2791 None,
2792 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002793 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002794 "error_if_validators": (
2795 TosaErrorValidator.evRankMismatch,
2796 TosaErrorValidator.evWrongInputType,
2797 TosaErrorValidator.evWrongOutputType,
2798 TosaErrorValidator.evWrongInputList,
2799 TosaErrorValidator.evWrongOutputList,
2800 TosaErrorValidator.evDimensionMismatch,
2801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002802 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002803 "logical_right_shift": {
2804 "op": Op.LOGICAL_RIGHT_SHIFT,
2805 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002806 "build_fcn": (
2807 build_binary_broadcast,
2808 TosaTensorGen.tgBroadcastFuzz,
2809 TosaTensorValuesGen.tvgLogicalShift,
2810 None,
2811 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002812 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002813 "error_if_validators": (
2814 TosaErrorValidator.evRankMismatch,
2815 TosaErrorValidator.evWrongInputType,
2816 TosaErrorValidator.evWrongOutputType,
2817 TosaErrorValidator.evWrongInputList,
2818 TosaErrorValidator.evWrongOutputList,
2819 TosaErrorValidator.evDimensionMismatch,
2820 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002821 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002822 "logical_or": {
2823 "op": Op.LOGICAL_OR,
2824 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002825 "build_fcn": (
2826 build_binary_broadcast,
2827 TosaTensorGen.tgBroadcastFuzz,
2828 TosaTensorValuesGen.tvgDefault,
2829 None,
2830 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002831 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002832 "error_if_validators": (
2833 TosaErrorValidator.evRankMismatch,
2834 TosaErrorValidator.evWrongInputType,
2835 TosaErrorValidator.evWrongOutputType,
2836 TosaErrorValidator.evWrongInputList,
2837 TosaErrorValidator.evWrongOutputList,
2838 TosaErrorValidator.evDimensionMismatch,
2839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002841 "logical_xor": {
2842 "op": Op.LOGICAL_XOR,
2843 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002844 "build_fcn": (
2845 build_binary_broadcast,
2846 TosaTensorGen.tgBroadcastFuzz,
2847 TosaTensorValuesGen.tvgDefault,
2848 None,
2849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002850 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002851 "error_if_validators": (
2852 TosaErrorValidator.evRankMismatch,
2853 TosaErrorValidator.evWrongInputType,
2854 TosaErrorValidator.evWrongOutputType,
2855 TosaErrorValidator.evWrongInputList,
2856 TosaErrorValidator.evWrongOutputList,
2857 TosaErrorValidator.evDimensionMismatch,
2858 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002859 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002860 "maximum": {
2861 "op": Op.MAXIMUM,
2862 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002863 "build_fcn": (
2864 build_binary_broadcast,
2865 TosaTensorGen.tgBroadcastFuzz,
2866 TosaTensorValuesGen.tvgDefault,
2867 None,
2868 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002869 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002870 "error_if_validators": (
2871 TosaErrorValidator.evRankMismatch,
2872 TosaErrorValidator.evWrongInputType,
2873 TosaErrorValidator.evWrongOutputType,
2874 TosaErrorValidator.evWrongInputList,
2875 TosaErrorValidator.evWrongOutputList,
2876 TosaErrorValidator.evDimensionMismatch,
2877 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002879 "minimum": {
2880 "op": Op.MINIMUM,
2881 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002882 "build_fcn": (
2883 build_binary_broadcast,
2884 TosaTensorGen.tgBroadcastFuzz,
2885 TosaTensorValuesGen.tvgDefault,
2886 None,
2887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002888 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002889 "error_if_validators": (
2890 TosaErrorValidator.evRankMismatch,
2891 TosaErrorValidator.evWrongInputType,
2892 TosaErrorValidator.evWrongOutputType,
2893 TosaErrorValidator.evWrongInputList,
2894 TosaErrorValidator.evWrongOutputList,
2895 TosaErrorValidator.evDimensionMismatch,
2896 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002898 "mul": {
2899 "op": Op.MUL,
2900 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002901 "build_fcn": (
2902 build_mul,
2903 TosaTensorGen.tgBroadcastFuzz,
2904 TosaTensorValuesGen.tvgMul,
2905 TosaArgGen.agMul,
2906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002907 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002908 "error_if_validators": (
2909 TosaErrorValidator.evWrongInputType,
2910 TosaErrorValidator.evWrongOutputType,
2911 TosaErrorValidator.evWrongInputList,
2912 TosaErrorValidator.evWrongOutputList,
2913 TosaErrorValidator.evRankMismatch,
2914 TosaErrorValidator.evDimensionMismatch,
2915 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002916 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 "pow": {
2918 "op": Op.POW,
2919 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002920 "build_fcn": (
2921 build_binary_broadcast,
2922 TosaTensorGen.tgBroadcastFuzz,
2923 TosaTensorValuesGen.tvgDefault,
2924 None,
2925 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002926 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002927 "error_if_validators": (
2928 TosaErrorValidator.evRankMismatch,
2929 TosaErrorValidator.evWrongInputType,
2930 TosaErrorValidator.evWrongOutputType,
2931 TosaErrorValidator.evWrongInputList,
2932 TosaErrorValidator.evWrongOutputList,
2933 TosaErrorValidator.evDimensionMismatch,
2934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002936 "sub": {
2937 "op": Op.SUB,
2938 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 "build_fcn": (
2940 build_binary_broadcast,
2941 TosaTensorGen.tgBroadcastFuzz,
2942 TosaTensorValuesGen.tvgAddSub,
2943 None,
2944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002945 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002946 "error_if_validators": (
2947 TosaErrorValidator.evRankMismatch,
2948 TosaErrorValidator.evWrongInputType,
2949 TosaErrorValidator.evWrongOutputType,
2950 TosaErrorValidator.evWrongInputList,
2951 TosaErrorValidator.evWrongOutputList,
2952 TosaErrorValidator.evDimensionMismatch,
2953 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002954 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002955 "table": {
2956 "op": Op.TABLE,
2957 # Use the automatic generation functions to create the input array
2958 # but create the table tensor in the build function, as it may be
2959 # a different type from the input
2960 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002961 "build_fcn": (
2962 build_table,
2963 TosaTensorGen.tgBasic,
2964 TosaTensorValuesGen.tvgDefault,
2965 TosaArgGen.agTable,
2966 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002967 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002968 "error_if_validators": (
2969 TosaErrorValidator.evWrongInputType,
2970 TosaErrorValidator.evWrongOutputType,
2971 TosaErrorValidator.evWrongInputList,
2972 TosaErrorValidator.evWrongOutputList,
2973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002975 # Elementwise Unary operators
2976 "abs": {
2977 "op": Op.ABS,
2978 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002979 "build_fcn": (
2980 build_unary,
2981 TosaTensorGen.tgBasic,
2982 TosaTensorValuesGen.tvgDefault,
2983 None,
2984 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002985 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002986 "error_if_validators": (
2987 TosaErrorValidator.evWrongInputType,
2988 TosaErrorValidator.evWrongOutputType,
2989 TosaErrorValidator.evWrongInputList,
2990 TosaErrorValidator.evWrongOutputList,
2991 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002992 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002993 "bitwise_not": {
2994 "op": Op.BITWISE_NOT,
2995 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002996 "build_fcn": (
2997 build_unary,
2998 TosaTensorGen.tgBasic,
2999 TosaTensorValuesGen.tvgDefault,
3000 None,
3001 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003002 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003003 "error_if_validators": (
3004 TosaErrorValidator.evWrongInputType,
3005 TosaErrorValidator.evWrongOutputType,
3006 TosaErrorValidator.evWrongInputList,
3007 TosaErrorValidator.evWrongOutputList,
3008 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003009 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 "ceil": {
3011 "op": Op.CEIL,
3012 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003013 "build_fcn": (
3014 build_unary,
3015 TosaTensorGen.tgBasic,
3016 TosaTensorValuesGen.tvgDefault,
3017 None,
3018 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003019 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003020 "error_if_validators": (
3021 TosaErrorValidator.evWrongInputType,
3022 TosaErrorValidator.evWrongOutputType,
3023 TosaErrorValidator.evWrongInputList,
3024 TosaErrorValidator.evWrongOutputList,
3025 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003026 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003027 "clz": {
3028 "op": Op.CLZ,
3029 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003030 "build_fcn": (
3031 build_unary,
3032 TosaTensorGen.tgBasic,
3033 TosaTensorValuesGen.tvgDefault,
3034 None,
3035 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003036 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003037 "error_if_validators": (
3038 TosaErrorValidator.evWrongInputType,
3039 TosaErrorValidator.evWrongOutputType,
3040 TosaErrorValidator.evWrongInputList,
3041 TosaErrorValidator.evWrongOutputList,
3042 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003043 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003044 "exp": {
3045 "op": Op.EXP,
3046 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003047 "build_fcn": (
3048 build_unary,
3049 TosaTensorGen.tgBasic,
3050 TosaTensorValuesGen.tvgDefault,
3051 None,
3052 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003053 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003054 "error_if_validators": (
3055 TosaErrorValidator.evWrongInputType,
3056 TosaErrorValidator.evWrongOutputType,
3057 TosaErrorValidator.evWrongInputList,
3058 TosaErrorValidator.evWrongOutputList,
3059 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003060 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003061 "floor": {
3062 "op": Op.FLOOR,
3063 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003064 "build_fcn": (
3065 build_unary,
3066 TosaTensorGen.tgBasic,
3067 TosaTensorValuesGen.tvgDefault,
3068 None,
3069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003071 "error_if_validators": (
3072 TosaErrorValidator.evWrongInputType,
3073 TosaErrorValidator.evWrongOutputType,
3074 TosaErrorValidator.evWrongInputList,
3075 TosaErrorValidator.evWrongOutputList,
3076 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003077 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003078 "log": {
3079 "op": Op.LOG,
3080 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003081 "build_fcn": (
3082 build_unary,
3083 TosaTensorGen.tgBasic,
3084 TosaTensorValuesGen.tvgDefault,
3085 None,
3086 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003087 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003088 "error_if_validators": (
3089 TosaErrorValidator.evWrongInputType,
3090 TosaErrorValidator.evWrongOutputType,
3091 TosaErrorValidator.evWrongInputList,
3092 TosaErrorValidator.evWrongOutputList,
3093 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003094 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003095 "logical_not": {
3096 "op": Op.LOGICAL_NOT,
3097 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003098 "build_fcn": (
3099 build_unary,
3100 TosaTensorGen.tgBasic,
3101 TosaTensorValuesGen.tvgDefault,
3102 None,
3103 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003104 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003105 "error_if_validators": (
3106 TosaErrorValidator.evWrongInputType,
3107 TosaErrorValidator.evWrongOutputType,
3108 TosaErrorValidator.evWrongInputList,
3109 TosaErrorValidator.evWrongOutputList,
3110 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003111 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003112 "negate": {
3113 "op": Op.NEGATE,
3114 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003115 "build_fcn": (
3116 build_unary,
3117 TosaTensorGen.tgBasic,
3118 TosaTensorValuesGen.tvgNegate,
3119 None,
3120 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003121 "qgen": TosaQuantGen.qgUnary,
3122 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003123 "error_if_validators": (
3124 TosaErrorValidator.evInputZeroPointNotZero,
3125 TosaErrorValidator.evOutputZeroPointNotZero,
3126 TosaErrorValidator.evWrongInputType,
3127 TosaErrorValidator.evWrongOutputType,
3128 TosaErrorValidator.evWrongInputList,
3129 TosaErrorValidator.evWrongOutputList,
3130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "reciprocal": {
3133 "op": Op.RECIPROCAL,
3134 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003135 "build_fcn": (
3136 build_unary,
3137 TosaTensorGen.tgBasic,
3138 TosaTensorValuesGen.tvgDefault,
3139 None,
3140 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003142 "error_if_validators": (
3143 TosaErrorValidator.evWrongInputType,
3144 TosaErrorValidator.evWrongOutputType,
3145 TosaErrorValidator.evWrongInputList,
3146 TosaErrorValidator.evWrongOutputList,
3147 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003148 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003149 "rsqrt": {
3150 "op": Op.RSQRT,
3151 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003152 "build_fcn": (
3153 build_unary,
3154 TosaTensorGen.tgBasic,
3155 TosaTensorValuesGen.tvgDefault,
3156 None,
3157 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003159 "error_if_validators": (
3160 TosaErrorValidator.evWrongInputType,
3161 TosaErrorValidator.evWrongOutputType,
3162 TosaErrorValidator.evWrongInputList,
3163 TosaErrorValidator.evWrongOutputList,
3164 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003166 # Elementwise Ternary operators
3167 "select": {
3168 "op": Op.SELECT,
3169 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003170 "build_fcn": (
3171 build_select,
3172 TosaTensorGen.tgBroadcastFuzz,
3173 TosaTensorValuesGen.tvgSelect,
3174 None,
3175 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003177 "error_if_validators": (
3178 TosaErrorValidator.evRankMismatch,
3179 TosaErrorValidator.evWrongInputType,
3180 TosaErrorValidator.evWrongOutputType,
3181 TosaErrorValidator.evWrongInputList,
3182 TosaErrorValidator.evWrongOutputList,
3183 TosaErrorValidator.evDimensionMismatch,
3184 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003186 # Comparison operators
3187 "equal": {
3188 "op": Op.EQUAL,
3189 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190 "build_fcn": (
3191 build_comparison,
3192 TosaTensorGen.tgBroadcastFuzz,
3193 TosaTensorValuesGen.tvgEqual,
3194 None,
3195 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003196 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003197 "error_if_validators": (
3198 TosaErrorValidator.evRankMismatch,
3199 TosaErrorValidator.evWrongInputType,
3200 TosaErrorValidator.evWrongOutputType,
3201 TosaErrorValidator.evWrongInputList,
3202 TosaErrorValidator.evWrongOutputList,
3203 TosaErrorValidator.evDimensionMismatch,
3204 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003206 "greater_equal": {
3207 "op": Op.GREATER_EQUAL,
3208 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 "build_fcn": (
3210 build_comparison,
3211 TosaTensorGen.tgBroadcastFuzz,
3212 TosaTensorValuesGen.tvgDefault,
3213 None,
3214 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003215 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003216 "error_if_validators": (
3217 TosaErrorValidator.evRankMismatch,
3218 TosaErrorValidator.evWrongInputType,
3219 TosaErrorValidator.evWrongOutputType,
3220 TosaErrorValidator.evWrongInputList,
3221 TosaErrorValidator.evWrongOutputList,
3222 TosaErrorValidator.evDimensionMismatch,
3223 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003225 "greater": {
3226 "op": Op.GREATER,
3227 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 "build_fcn": (
3229 build_comparison,
3230 TosaTensorGen.tgBroadcastFuzz,
3231 TosaTensorValuesGen.tvgDefault,
3232 None,
3233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003235 "error_if_validators": (
3236 TosaErrorValidator.evRankMismatch,
3237 TosaErrorValidator.evWrongInputType,
3238 TosaErrorValidator.evWrongOutputType,
3239 TosaErrorValidator.evWrongInputList,
3240 TosaErrorValidator.evWrongOutputList,
3241 TosaErrorValidator.evDimensionMismatch,
3242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003244 # Reduction operators
3245 "reduce_all": {
3246 "op": Op.REDUCE_ALL,
3247 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003248 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003249 "build_fcn": (
3250 build_reduce,
3251 TosaTensorGen.tgBasic,
3252 TosaTensorValuesGen.tvgDefault,
3253 TosaArgGen.agAxis,
3254 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003255 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003256 "error_if_validators": (
3257 TosaErrorValidator.evAxisLargerRank,
3258 TosaErrorValidator.evAxisSmallerZero,
3259 TosaErrorValidator.evShapeOfAxisNotOne,
3260 TosaErrorValidator.evWrongInputType,
3261 TosaErrorValidator.evWrongOutputType,
3262 TosaErrorValidator.evWrongRank,
3263 TosaErrorValidator.evWrongInputList,
3264 TosaErrorValidator.evWrongOutputList,
3265 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003266 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003267 "reduce_any": {
3268 "op": Op.REDUCE_ANY,
3269 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003270 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003271 "build_fcn": (
3272 build_reduce,
3273 TosaTensorGen.tgBasic,
3274 TosaTensorValuesGen.tvgDefault,
3275 TosaArgGen.agAxis,
3276 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003277 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003278 "error_if_validators": (
3279 TosaErrorValidator.evAxisLargerRank,
3280 TosaErrorValidator.evAxisSmallerZero,
3281 TosaErrorValidator.evShapeOfAxisNotOne,
3282 TosaErrorValidator.evWrongInputType,
3283 TosaErrorValidator.evWrongOutputType,
3284 TosaErrorValidator.evWrongRank,
3285 TosaErrorValidator.evWrongInputList,
3286 TosaErrorValidator.evWrongOutputList,
3287 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003288 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003289 "reduce_max": {
3290 "op": Op.REDUCE_MAX,
3291 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003292 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003293 "build_fcn": (
3294 build_reduce,
3295 TosaTensorGen.tgBasic,
3296 TosaTensorValuesGen.tvgDefault,
3297 TosaArgGen.agAxis,
3298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003300 "error_if_validators": (
3301 TosaErrorValidator.evAxisLargerRank,
3302 TosaErrorValidator.evAxisSmallerZero,
3303 TosaErrorValidator.evShapeOfAxisNotOne,
3304 TosaErrorValidator.evWrongInputType,
3305 TosaErrorValidator.evWrongOutputType,
3306 TosaErrorValidator.evWrongRank,
3307 TosaErrorValidator.evWrongInputList,
3308 TosaErrorValidator.evWrongOutputList,
3309 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003312 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003314 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003315 "build_fcn": (
3316 build_reduce,
3317 TosaTensorGen.tgBasic,
3318 TosaTensorValuesGen.tvgDefault,
3319 TosaArgGen.agAxis,
3320 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003322 "error_if_validators": (
3323 TosaErrorValidator.evAxisLargerRank,
3324 TosaErrorValidator.evAxisSmallerZero,
3325 TosaErrorValidator.evShapeOfAxisNotOne,
3326 TosaErrorValidator.evWrongInputType,
3327 TosaErrorValidator.evWrongOutputType,
3328 TosaErrorValidator.evWrongRank,
3329 TosaErrorValidator.evWrongInputList,
3330 TosaErrorValidator.evWrongOutputList,
3331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "reduce_product": {
3334 "op": Op.REDUCE_PRODUCT,
3335 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003336 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 "build_fcn": (
3338 build_reduce,
3339 TosaTensorGen.tgBasic,
3340 TosaTensorValuesGen.tvgDefault,
3341 TosaArgGen.agAxis,
3342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003344 "error_if_validators": (
3345 TosaErrorValidator.evAxisLargerRank,
3346 TosaErrorValidator.evAxisSmallerZero,
3347 TosaErrorValidator.evShapeOfAxisNotOne,
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongRank,
3351 TosaErrorValidator.evWrongInputList,
3352 TosaErrorValidator.evWrongOutputList,
3353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "reduce_sum": {
3356 "op": Op.REDUCE_SUM,
3357 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003358 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003359 "build_fcn": (
3360 build_reduce,
3361 TosaTensorGen.tgBasic,
3362 TosaTensorValuesGen.tvgReduceSum,
3363 TosaArgGen.agAxis,
3364 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003365 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003366 "error_if_validators": (
3367 TosaErrorValidator.evAxisLargerRank,
3368 TosaErrorValidator.evAxisSmallerZero,
3369 TosaErrorValidator.evShapeOfAxisNotOne,
3370 TosaErrorValidator.evWrongInputType,
3371 TosaErrorValidator.evWrongOutputType,
3372 TosaErrorValidator.evWrongRank,
3373 TosaErrorValidator.evWrongInputList,
3374 TosaErrorValidator.evWrongOutputList,
3375 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003377 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003378 "concat": {
3379 "op": Op.CONCAT,
3380 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003381 "build_fcn": (
3382 build_concat,
3383 TosaTensorGen.tgConcat,
3384 TosaTensorValuesGen.tvgConcat,
3385 TosaArgGen.agAxis,
3386 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003387 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003388 "error_if_validators": (
3389 TosaErrorValidator.evAxisLargerRank,
3390 TosaErrorValidator.evAxisSmallerZero,
3391 TosaErrorValidator.evConcatInputRankMismatch,
3392 TosaErrorValidator.evConcatShapeSumMismatch,
3393 TosaErrorValidator.evConcatInputDimMismatch,
3394 TosaErrorValidator.evWrongInputType,
3395 TosaErrorValidator.evWrongOutputType,
3396 TosaErrorValidator.evWrongOutputList,
3397 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003398 },
3399 "pad": {
3400 "op": Op.PAD,
3401 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003402 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003403 "build_fcn": (
3404 build_pad,
3405 TosaTensorGen.tgBasic,
3406 TosaTensorValuesGen.tvgDefault,
3407 TosaArgGen.agPad,
3408 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003409 "qgen": TosaQuantGen.qgPad,
3410 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003411 "error_if_validators": (
3412 TosaErrorValidator.evWrongInputType,
3413 TosaErrorValidator.evPadSmallerZero,
3414 TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongInputList,
3416 TosaErrorValidator.evWrongOutputList,
3417 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003418 },
3419 "reshape": {
3420 "op": Op.RESHAPE,
3421 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 "build_fcn": (
3423 build_reshape,
3424 TosaTensorGen.tgBasic,
3425 TosaTensorValuesGen.tvgDefault,
3426 TosaArgGen.agReshape,
3427 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003428 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 "error_if_validators": (
3430 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3431 TosaErrorValidator.evWrongInputType,
3432 TosaErrorValidator.evWrongOutputType,
3433 TosaErrorValidator.evWrongInputList,
3434 TosaErrorValidator.evWrongOutputList,
3435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003436 },
3437 "reverse": {
3438 "op": Op.REVERSE,
3439 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003440 "build_fcn": (
3441 build_reverse,
3442 TosaTensorGen.tgBasic,
3443 TosaTensorValuesGen.tvgDefault,
3444 TosaArgGen.agAxis,
3445 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003446 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 "error_if_validators": (
3448 TosaErrorValidator.evAxisSmallerZero,
3449 TosaErrorValidator.evAxisLargerRank,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongInputList,
3453 TosaErrorValidator.evWrongOutputList,
3454 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003455 },
3456 "slice": {
3457 "op": Op.SLICE,
3458 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003459 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003460 "build_fcn": (
3461 build_slice,
3462 TosaTensorGen.tgBasic,
3463 TosaTensorValuesGen.tvgDefault,
3464 TosaArgGen.agSlice,
3465 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003466 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003467 "error_if_validators": (
3468 TosaErrorValidator.evStartSmallerZero,
3469 TosaErrorValidator.evSizeSmallerEqualZero,
3470 TosaErrorValidator.evStartSizeOutsideBounds,
3471 TosaErrorValidator.evSizeOutputShapeMismatch,
3472 TosaErrorValidator.evInputSizeStartLengthMismatch,
3473 TosaErrorValidator.evWrongRank,
3474 TosaErrorValidator.evWrongInputType,
3475 TosaErrorValidator.evWrongOutputType,
3476 TosaErrorValidator.evWrongInputList,
3477 TosaErrorValidator.evWrongOutputList,
3478 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003479 },
3480 "tile": {
3481 "op": Op.TILE,
3482 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003483 "build_fcn": (
3484 build_tile,
3485 TosaTensorGen.tgBasic,
3486 TosaTensorValuesGen.tvgDefault,
3487 TosaArgGen.agTile,
3488 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003489 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003490 "error_if_validators": (
3491 TosaErrorValidator.evWrongInputType,
3492 TosaErrorValidator.evWrongOutputType,
3493 TosaErrorValidator.evWrongInputList,
3494 TosaErrorValidator.evWrongOutputList,
3495 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003496 },
3497 "transpose": {
3498 "op": Op.TRANSPOSE,
3499 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003500 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003501 "build_fcn": (
3502 build_transpose,
3503 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003504 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003505 TosaArgGen.agTranspose,
3506 ),
3507 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003508 "error_if_validators": (
3509 TosaErrorValidator.evIndexOutsideBounds,
3510 TosaErrorValidator.evIndexUsedTwice,
3511 TosaErrorValidator.evWrongInputType,
3512 TosaErrorValidator.evWrongOutputType,
3513 TosaErrorValidator.evWrongInputList,
3514 TosaErrorValidator.evWrongOutputList,
3515 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003516 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003517 # Data nodes
3518 "const": {
3519 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003520 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003521 "build_fcn": (
3522 build_const,
3523 TosaTensorGen.tgBasic,
3524 TosaTensorValuesGen.tvgDefault,
3525 None,
3526 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 "types": TYPE_FIB,
3528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 "identity": {
3530 "op": Op.IDENTITY,
3531 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003532 "build_fcn": (
3533 build_unary,
3534 TosaTensorGen.tgBasic,
3535 TosaTensorValuesGen.tvgDefault,
3536 None,
3537 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003538 "types": TYPE_FIB,
3539 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003540 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003541 "gather": {
3542 "op": Op.GATHER,
3543 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3544 "operands": (1, 0),
3545 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 "build_fcn": (
3547 build_gather,
3548 TosaTensorGen.tgBasic,
3549 TosaTensorValuesGen.tvgDefault,
3550 None,
3551 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003552 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003553 "error_if_validators": (
3554 TosaErrorValidator.evWrongInputType,
3555 TosaErrorValidator.evWrongOutputType,
3556 TosaErrorValidator.evWrongInputList,
3557 TosaErrorValidator.evWrongOutputList,
3558 TosaErrorValidator.evWrongRank,
3559 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003560 },
3561 "scatter": {
3562 "op": Op.SCATTER,
3563 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003564 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003565 "operands": (2, 0),
3566 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003567 "build_fcn": (
3568 build_scatter,
3569 TosaTensorGen.tgScatter,
3570 TosaTensorValuesGen.tvgDefault,
3571 None,
3572 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003573 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003574 "error_if_validators": (
3575 TosaErrorValidator.evWrongInputType,
3576 TosaErrorValidator.evWrongOutputType,
3577 TosaErrorValidator.evWrongInputList,
3578 TosaErrorValidator.evWrongOutputList,
3579 TosaErrorValidator.evWrongRank,
3580 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003581 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003582 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003583 "resize": {
3584 "op": Op.RESIZE,
3585 "operands": (1, 0),
3586 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003587 "build_fcn": (
3588 build_resize,
3589 TosaTensorGen.tgNHWC,
3590 TosaTensorValuesGen.tvgDefault,
3591 TosaArgGen.agResize,
3592 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003593 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003594 "invalid_test_validators": (
3595 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
3596 TosaInvalidValidator.ivBadStride,
3597 ),
3598 "error_if_validators": (
3599 TosaErrorValidator.evMaxDimExceeded,
3600 TosaErrorValidator.evStrideSmallerEqualZero,
3601 TosaErrorValidator.evStrideLargerDimension,
3602 TosaErrorValidator.evStrideLargerEqualMax,
3603 TosaErrorValidator.evOffsetSmallerEqualMin,
3604 TosaErrorValidator.evOffsetLargerEqualMax,
3605 TosaErrorValidator.evShiftNotZero,
3606 TosaErrorValidator.evShiftSmallerOne,
3607 TosaErrorValidator.evShiftLargerEleven,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongRank,
3611 TosaErrorValidator.evWrongInputList,
3612 TosaErrorValidator.evWrongOutputList,
3613 TosaErrorValidator.evBatchMismatch,
3614 TosaErrorValidator.evChannelMismatch,
3615 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003616 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003617 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003618 "cast": {
3619 "op": Op.CAST,
3620 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003621 "build_fcn": (
3622 build_cast,
3623 TosaTensorGen.tgBasic,
3624 TosaTensorValuesGen.tvgDefault,
3625 TosaArgGen.agCast,
3626 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003627 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003628 "error_if_validators": (
3629 TosaErrorValidator.evWrongInputType,
3630 TosaErrorValidator.evWrongOutputType,
3631 TosaErrorValidator.evWrongInputList,
3632 TosaErrorValidator.evWrongOutputList,
3633 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 },
3635 "rescale": {
3636 "op": Op.RESCALE,
3637 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003638 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 "build_fcn": (
3640 build_rescale,
3641 TosaTensorGen.tgBasic,
3642 TosaTensorValuesGen.tvgDefault,
3643 TosaArgGen.agRescale,
3644 ),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01003645 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003646 "error_if_validators": (
3647 TosaErrorValidator.evInputZeroPointNotZero,
3648 TosaErrorValidator.evOutputZeroPointNotZero,
3649 TosaErrorValidator.evScaleTrue,
3650 TosaErrorValidator.evScaleNotTrue,
3651 TosaErrorValidator.evWrongInputType,
3652 TosaErrorValidator.evWrongOutputType,
3653 TosaErrorValidator.evWrongRank,
3654 TosaErrorValidator.evWrongInputList,
3655 TosaErrorValidator.evWrongOutputList,
3656 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003657 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003658 # Custom
3659 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003661 # Two varients of cond_if, one that generates one of two constant tensors (no
3662 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3663 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003664 "cond_if_const": {
3665 "op": Op.COND_IF,
3666 "operands": (0, 2),
3667 "build_fcn": (
3668 build_cond_if_const,
3669 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003670 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003671 TosaArgGen.agCondIf,
3672 ),
3673 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003674 "error_if_validators": (
3675 TosaErrorValidator.evOutputListThenGraphMismatch,
3676 TosaErrorValidator.evOutputListElseGraphMismatch,
3677 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003678 },
3679 "cond_if_binary": {
3680 "op": Op.COND_IF,
3681 "operands": (2, 0),
3682 "build_fcn": (
3683 build_cond_if_binary,
3684 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003685 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003686 TosaArgGen.agCondIf,
3687 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003688 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003689 "error_if_validators": (
3690 TosaErrorValidator.evInputListThenGraphMismatch,
3691 TosaErrorValidator.evInputListElseGraphMismatch,
3692 TosaErrorValidator.evOutputListThenGraphMismatch,
3693 TosaErrorValidator.evOutputListElseGraphMismatch,
3694 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003695 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003696 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003697 "while_loop": {
3698 "op": Op.WHILE_LOOP,
3699 "operands": (0, 1),
3700 "build_fcn": (
3701 build_while_loop,
3702 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003703 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003704 TosaArgGen.agWhileLoop,
3705 ),
3706 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003707 "error_if_validators": (
3708 TosaErrorValidator.evInputListOutputListMismatch,
3709 TosaErrorValidator.evInputListCondGraphMismatch,
3710 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3711 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3712 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3713 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003714 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003715 }
3716
Kevin Cheng550ccc52021-03-03 11:21:43 -08003717
Eric Kunzee5e26762020-10-13 16:11:07 -07003718class OutputShaper:
3719 # Methods in this class compute the expected output shape and datatype
3720 # for common classes of operations
def __init__(self):
    """Create an OutputShaper.

    The constructor stores nothing on the instance; the shape helpers are
    invoked with an explicit serializer instead of relying on instance state.
    """
3723
3724 # These methods return arguments that can be used for
3725 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary operator with broadcast.

    Args:
        ser: serializer; the result of ``ser.addOutput(shape, dtype)`` is
            returned.
        rng: random generator, only drawn from for ``ErrorIf.WrongOutputType``.
        a, b: input tensors.  Their ranks must match unless
            ``ErrorIf.RankMismatch`` is being generated; their dtypes must
            always match.
        error_name: the ERROR_IF condition being generated, or None.

    For valid tests (``error_name is None``) each size-1 dimension of ``a``
    takes the size of the same dimension of ``b``.  When an error is being
    generated, ``a``'s sizes are used unchanged.  For
    ``ErrorIf.WrongOutputType`` the output dtype is drawn from
    INT8/INT16/INT32/INT48/FLOAT, excluding ``a``'s dtype.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003753
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for an elementwise binary operator without broadcast.

    ``a`` and ``b`` must agree in rank, dtype and every dimension (checked
    with asserts).  The output uses a fresh list copy of ``a``'s shape and
    ``a``'s dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for idx, dim in enumerate(a.shape):
        assert dim == b.shape[idx]

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003765
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary operator.

    The output has ``a``'s shape.  Its dtype is ``a``'s dtype, except for
    ``ErrorIf.WrongOutputType``, where a different dtype is drawn from
    ``rng`` out of INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003782
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    Ranks of ``cond``, ``a`` and ``b`` must match unless
    ``ErrorIf.RankMismatch`` is being generated; ``a`` and ``b`` must share
    a dtype.  The output follows ``cond``'s shape.  For valid tests
    (``error_name is None``), each size-1 dimension of ``cond`` becomes the
    largest size among the three inputs in that dimension.  For
    ``ErrorIf.WrongOutputType`` the dtype is drawn from
    INT8/INT16/INT32/INT48/FLOAT, excluding ``a``'s dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and broadcasting)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003810
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise comparison operator.

    Each size-1 dimension of ``a`` takes ``b``'s size in that dimension,
    provided ``b`` has it.  This broadcast is applied even when
    ``error_name`` is set.  The output dtype is BOOL, or, for
    ``ErrorIf.WrongOutputType``, a non-BOOL dtype drawn from ``rng``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast; the rank guard matters when ErrorIf.RankMismatch lets
    # b be shorter than a.
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # BOOL, the correct result type, is deliberately not a candidate.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003838
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction of ``a`` along ``axis``.

    The output keeps ``a``'s rank with dimension ``axis`` set to 1, with
    these ERROR_IF exceptions:

    * AxisSmallerZero / AxisLargerRank: ``a``'s shape is kept as-is.
    * ShapeOfAxisNotOne: the ``axis`` size is kept, or, if it is 1, replaced
      by a random size in [2, 10).

    For ``ErrorIf.WrongOutputType`` the dtype is drawn from
    INT8/INT16/INT32/INT48/FLOAT, excluding ``a``'s dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # The reduced axis must deliberately end up != 1.
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003865
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor for ARGMAX of ``a`` along ``axis``.

    The output shape is ``a``'s shape with ``axis`` removed.  The axis is
    kept for the AxisSmallerZero / AxisLargerRank cases.  Other ERROR_IF
    variants:

    * ArgmaxOutputRankMismatch: randomly drops the leading dimension (only
      when more than one remains), otherwise appends a trailing 1.
    * ArgmaxOutputShapeMismatch: grows every dimension by a random 1..9.
    * WrongOutputType: a dtype other than INT32 is drawn from ``rng``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003897
@staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Add the output tensor for CONV2D.

    Layouts: ``ifm`` is NHWC, ``filter`` is OHWI and the output is NHWC with
    ``filter.shape[0]`` channels.  ``padding[0:2]`` pads H and
    ``padding[2:4]`` pads W.  Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    ``ErrorIf.ConvOutputShapeMismatch`` grows H, W or both by a random
    multiple (1-3) of the matching stride.

    The output dtype is INT32 for INT8 input, INT48 for INT16 and FLOAT for
    FLOAT.  INT32 is used for an unsupported input under
    ``ErrorIf.WrongInputType``.  ``ErrorIf.WrongOutputType`` draws any other
    ``usableDTypes`` entry.

    Raises:
        Exception: for an unsupported input dtype outside WrongInputType.
    """

    def _spatial_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Output size of one spatial dimension for the given geometry.
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    out_h = _spatial_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _spatial_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 3):
            out_h = out_h + (rng.choice(steps) * strides[0])
        if which in (2, 3):
            out_w = out_w + (rng.choice(steps) * strides[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    out_dtype_for = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in out_dtype_for:
        out_dtype = out_dtype_for[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003949
@staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Add the output tensor for CONV3D.

    Layouts: ``ifm`` is NDHWC, ``filter`` is ODHWI and the output is NDHWC
    with ``filter.shape[0]`` channels.  For spatial dimension k (D, H, W),
    the padding pair is ``padding[2k], padding[2k + 1]`` and the size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    ``ErrorIf.ConvOutputShapeMismatch`` picks 1..4.  Values 1-3 grow D, H or
    W respectively; 4 grows all three.  Each grown dimension increases by a
    random multiple (1-4) of its stride.

    The output dtype is INT32 for INT8 input, INT48 for INT16 and FLOAT for
    FLOAT.  INT32 is used for an unsupported input under
    ``ErrorIf.WrongInputType``.  ``ErrorIf.WrongOutputType`` draws any other
    ``usableDTypes`` entry.

    Raises:
        Exception: for an unsupported input dtype outside WrongInputType.
    """
    out_spatial = []
    for axis_idx in range(3):
        out_spatial.append(
            (
                ifm.shape[axis_idx + 1]
                - 1
                + padding[2 * axis_idx]
                + padding[2 * axis_idx + 1]
                - (filter.shape[axis_idx + 1] - 1) * dilations[axis_idx]
            )
            // strides[axis_idx]
            + 1
        )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        # Draws happen in D, H, W order, exactly one per grown dimension.
        for axis_idx in range(3):
            if which in (axis_idx + 1, 4):
                out_spatial[axis_idx] = out_spatial[axis_idx] + (
                    rng.choice(steps) * strides[axis_idx]
                )

    ofm_shape = [ifm.shape[0], *out_spatial, filter.shape[0]]

    out_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(ifm.dtype)
    if out_dtype is None:
        if error_name != ErrorIf.WrongInputType:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
4011
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: ``ifm`` is NHWC, ``filter`` is HWCM and the output is NHW with
    ``C * M`` (``filter.shape[2] * filter.shape[3]``) channels.
    ``padding[0:2]`` pads H and ``padding[2:4]`` pads W.  Each spatial size
    is ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    ``ErrorIf.ConvOutputShapeMismatch`` grows H, W or both by a random
    multiple (1-3) of the matching stride.

    The output dtype is INT32 for INT8 input, INT48 for INT16 and FLOAT for
    FLOAT.  INT32 is used for an unsupported input under
    ``ErrorIf.WrongInputType``.  ``ErrorIf.WrongOutputType`` draws any other
    ``usableDTypes`` entry.

    Raises:
        Exception: for an unsupported input dtype outside WrongInputType.
    """
    # HWCM filter: the kernel height/width live in filter dims 0 and 1.
    out_hw = []
    for axis_idx in range(2):
        out_hw.append(
            (
                ifm.shape[axis_idx + 1]
                - 1
                + padding[2 * axis_idx]
                + padding[2 * axis_idx + 1]
                - (filter.shape[axis_idx] - 1) * dilations[axis_idx]
            )
            // strides[axis_idx]
            + 1
        )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        for axis_idx in range(2):
            if which in (axis_idx + 1, 3):
                out_hw[axis_idx] = out_hw[axis_idx] + (
                    rng.choice(steps) * strides[axis_idx]
                )

    ofm_shape = [ifm.shape[0], out_hw[0], out_hw[1], filter.shape[2] * filter.shape[3]]

    out_dtype_for = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in out_dtype_for:
        out_dtype = out_dtype_for[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004064
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (NHWC in and out).

    Spatial sizes are ``(in + pad_before + pad_after - kernel) // stride + 1``,
    with ``pad[0:2]`` for H and ``pad[2:4]`` for W.  If either stride is
    non-positive or any padding is negative, both sizes are simply 1.

    ``ErrorIf.PoolingOutputShapeMismatch`` grows H, W or both by a random
    multiple (1-3) of the matching stride.  The channel count and dtype
    follow ``ifm``.  For ``ErrorIf.WrongOutputType``, a dtype other than
    ``ifm``'s is drawn from INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    degenerate = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if degenerate:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 3):
            out_h = out_h + (rng.choice(steps) * stride[0])
        if which in (2, 3):
            out_w = out_w + (rng.choice(steps) * stride[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004100
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    Shapes: ``input`` is (N, IC), ``filter`` is (OC, IC) and the output is
    (N, OC).

    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input.  For ``ErrorIf.WrongOutputType``, a dtype other
    than that one is drawn from ``rng``.  For ``ErrorIf.WrongInputType``
    with an unsupported input dtype, INT32 is used as a plausible output
    dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.

    Raises:
        Exception: if the input dtype is unsupported and ``error_name`` is not
            ``ErrorIf.WrongInputType``.
    """
    output_shape = [input.shape[0], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        if input.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif input.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif input.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch `incorrect_types` would be unbound and the
            # failure would surface as a confusing UnboundLocalError.
            raise Exception("Unsupported input dtype: {}".format(input.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004148
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for MATMUL.

    Shapes: ``a`` is (N, H, C), ``b`` is (N, C, W) and the output is
    (N, H, W).

    The output dtype is INT32 for INT8 input, INT48 for INT16 input and
    FLOAT for FLOAT input (taken from ``a``).  For
    ``ErrorIf.WrongOutputType``, a dtype other than that one is drawn from
    ``rng``.  For ``ErrorIf.WrongInputType`` with an unsupported input
    dtype, INT32 is used as a plausible output dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.

    Raises:
        Exception: if ``a``'s dtype is unsupported and ``error_name`` is not
            ``ErrorIf.WrongInputType``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif a.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch `incorrect_types` would be unbound and the
            # failure would surface as a confusing UnboundLocalError.
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
        out_dtype = rng.choice(a=incorrect_types)
    elif a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004196
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor for CONCAT of the tensors in ``a`` along ``axis``.

    The output starts from the first input's shape.  The ``axis`` dimension
    is increased by every other input's ``axis`` size, except when that
    cannot be computed (ConcatInputRankMismatch, AxisLargerRank,
    AxisSmallerZero).  ``ErrorIf.ConcatShapeSumMismatch`` then adds a random
    5..9 to it.  The dtype follows the first input; for
    ``ErrorIf.WrongOutputType`` a different one is drawn from
    INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    first = a[0]
    others = a[1:]

    # calculate the output shape, if possible, otherwise just use the first input shape
    out_shape = first.shape.copy()
    cannot_sum = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_sum:
        for tensor in others:
            out_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        }
        out_dtype = rng.choice(list(candidates - set([first.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004230
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor for PAD of ``a``.

    Output dimension i is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    For ``ErrorIf.PadSmallerZero``, any size below 1 is clamped to 1 so the
    shape stays positive.  The dtype follows ``a``; for
    ``ErrorIf.WrongOutputType`` a different one is drawn from
    INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = [
        padding[idx][0] + padding[idx][1] + dim for idx, dim in enumerate(a.shape)
    ]

    # Fix negative output shape if error_if test causes it
    if error_name == ErrorIf.PadSmallerZero and min(out_shape) < 1:
        out_shape = [max(dim, 1) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004257
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor for RESHAPE of ``a`` to ``shape``.

    Every ``-1`` entry of ``shape`` is replaced by
    ``elements(a) // product(other entries)``.
    ``ErrorIf.TensorSizeInputOutputMismatch`` then grows every dimension by
    a random 1..9.  The dtype follows ``a``; for
    ``ErrorIf.WrongOutputType`` a different one is drawn from
    INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = shape.copy()

    in_elements = 1
    for dim in a.shape:
        in_elements *= dim

    known_elements = 1
    for dim in out_shape:
        if dim != -1:
            known_elements *= dim

    # Infer the -1 dimension(s); the division only happens if one exists.
    for idx, dim in enumerate(out_shape):
        if dim == -1:
            out_shape[idx] = in_elements // known_elements

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004295
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add the output tensor for SLICE; its shape is a copy of ``size``.

    ``start`` is not used here.  For ``ErrorIf.SizeOutputShapeMismatch``,
    each dimension is perturbed: sizes <= 2 grow by 1 or 2, and larger
    sizes move by -2, -1, +1 or +2.  The dtype follows ``a``; for
    ``ErrorIf.WrongOutputType`` a different one is drawn from
    INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small sizes only grow, so a positive size cannot reach 0.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004325
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor for TILE of ``a``.

    ``multiples`` must have one entry per dimension of ``a``; output
    dimension i is ``a.shape[i] * multiples[i]``.  The dtype follows ``a``;
    for ``ErrorIf.WrongOutputType`` a different one is drawn from
    INT8/INT16/INT32/INT48/FLOAT.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(multiples) == len(a.shape)
    out_shape = [dim * multiples[idx] for idx, dim in enumerate(a.shape)]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004349
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor for transposeOp to ``ser`` and return it.

    Normally output dim ``i`` is ``a.shape[perms[i]]``.  For
    ``ErrorIf.IndexOutsideBounds`` every output dim is ``a.shape[0]`` and
    ``perms`` is never used as an index.  Presumably this is because
    ``perms`` may hold out-of-range values in that case; confirm against
    the arg generator.

    With ``ErrorIf.WrongOutputType`` the dtype is drawn with ``rng`` from
    INT8/INT16/INT32/INT48/FLOAT, excluding ``a.dtype``.  Otherwise the
    dtype is ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    out_of_bounds = error_name == ErrorIf.IndexOutsideBounds
    for axis in range(len(output_shape)):
        output_shape[axis] = a.shape[0 if out_of_bounds else perms[axis]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        # Set-difference order is what rng.choice indexes into; keep it.
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004377
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor for gatherOp to ``ser`` and return it.

    The output shape is ``[values.shape[0], indices.shape[1],
    values.shape[2]]``.  The rank-3 check on ``values`` is skipped for
    ``ErrorIf.WrongRank``.  The rank-2 check on ``indices`` and the
    matching leading-dim check always run.

    With ``ErrorIf.WrongOutputType`` the dtype is drawn with ``rng`` from
    INT8/INT16/INT32/INT48/FLOAT, excluding ``values.dtype``.  Otherwise
    the dtype is ``values.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    n_dim = values.shape[0]
    w_dim = indices.shape[1]
    c_dim = values.shape[2]
    output_shape = [n_dim, w_dim, c_dim]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values.dtype
    else:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        # Set-difference order is what rng.choice indexes into; keep it.
        out_dtype = rng.choice(list(set(candidates) - {values.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004401
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor for scatterOp to ``ser`` and return it.

    The output has the same shape as ``values_in``.  The rank-3 check on
    ``values_in`` is skipped for ``ErrorIf.WrongRank``.  All other
    rank/dimension consistency checks always run.

    With ``ErrorIf.WrongOutputType`` the dtype is drawn with ``rng`` from
    INT8/INT16/INT32/INT48/FLOAT, excluding ``values_in.dtype``.
    Otherwise the dtype is ``values_in.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy, as sliceOp/tileOp/transposeOp do, so the output tensor is not
    # handed the very same shape object that values_in owns.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        # Set-difference order is what rng.choice indexes into; keep it.
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004428
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor for tableOp to ``ser`` and return it.

    The output has ``input.shape``.  The dtype depends on the input dtype:
    INT16 maps to INT32, and anything else maps to INT8.  The INT8/INT16
    input check is skipped for ``ErrorIf.WrongInputType``.

    With ``ErrorIf.WrongOutputType`` the dtype is instead drawn with
    ``rng`` from INT8/INT16/INT32/INT48/FLOAT, excluding the dtype above.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        result_dtype = DType.INT32
    else:
        result_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        # Declaration order is preserved; rng.choice indexes into it.
        result_dtype = rng.choice([d for d in pool if d != result_dtype])

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004446
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor for resizeOp to ``serializer`` and return it.

    Only ``input.shape``, ``output_dims``, ``output_dtype``, ``rng`` and
    ``error_name`` are read.  The remaining parameters are accepted for
    interface compatibility.

    The normal output shape is
    ``[input.shape[0], output_dims[0], output_dims[1], input.shape[3]]``.
    ErrorIf variants:

    * ``ErrorIf.WrongRank``: the shape is
      ``[input.shape[0], output_dims[0], output_dims[0], input.shape[0]]``.
      Only ``input.shape[0]`` is read from the input here.
    * ``ErrorIf.BatchMismatch``: the batch dim is increased by
      ``rng.integers(1, 10)``.
    * ``ErrorIf.ChannelMismatch``: the last dim is increased by
      ``rng.integers(1, 10)``.

    Returns:
        Whatever ``serializer.addOutput`` returns.
    """
    if error_name == ErrorIf.WrongRank:
        result_shape = [
            input.shape[0],
            output_dims[0],
            output_dims[0],
            input.shape[0],
        ]
    else:
        batch = input.shape[0]
        channels = input.shape[3]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        elif error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)
        result_shape = [batch, output_dims[0], output_dims[1], channels]

    return serializer.addOutput(result_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004494
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for interface compatibility
    with the other shapers but are not read.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004498
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Add the output tensor for transposeConv2DOp to ``ser`` and return it.

    For ``ErrorIf.ConvOutputShapeMismatch``, dims 1 and/or 2 of
    ``output_shape`` are increased by a random 1..3.  This is done IN PLACE,
    so the caller's ``output_shape`` object is modified too.

    The output dtype is derived from ``ifm.dtype``:

    * INT8 maps to INT32.
    * INT16 maps to INT48.
    * FLOAT maps to FLOAT.

    Any other input dtype maps to INT32 under ``ErrorIf.WrongInputType``
    and otherwise raises ``Exception``.  With ``ErrorIf.WrongOutputType``
    the dtype is replaced by a random choice from
    ``usableDTypes(excludes=[<that dtype>])``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)
        # which == 1 bumps dim 1 only, 2 bumps dim 2 only, 3 bumps both.
        # Dim 1 is handled before dim 2 so the rng draws stay in this order.
        if which != 2:
            output_shape[1] = output_shape[1] + rng.choice(deltas)
        if which != 1:
            output_shape[2] = output_shape[2] + rng.choice(deltas)

    accum_for_input = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accum_for_input:
        out_dtype = accum_for_input[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Input dtype is deliberately invalid; pick a plausible output dtype.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(output_shape, out_dtype)