blob: 583e1ed2827f7dfb92c81f54a533f0c16ebfa936 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
16from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000017from tosa.DType import DType
18from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010019
20
Eric Kunzee5e26762020-10-13 16:11:07 -070021class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010022 # Maximum rank of tensor supported by test generator.
23 TOSA_TENSOR_MAX_RANK = 6
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025 def __init__(self, args):
26 self.args = args
27 self.basePath = args.output_dir
28 self.random_seed = args.random_seed
29 self.ser = None
30 self.rng = np.random.default_rng(self.random_seed)
31 self.createDynamicOpLists()
32 self.initOpListDefaults()
33 self.quantGen = TosaQuantGen()
34 # Force makeShape to do a specific starting shape
35 self.targetted_shape = None
36
37 def createSerializer(self, opName, testPath):
38 self.testPath = os.path.join(opName, testPath)
39
40 fullPath = os.path.join(self.basePath, self.testPath)
41 os.makedirs(fullPath, exist_ok=True)
42 self.ser = ts.TosaSerializer(fullPath)
43
44 def getSerializer(self):
45 return self.ser
46
47 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080048 with open(
49 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
50 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070051 fd.write(self.ser.serialize())
52
Kevin Cheng550ccc52021-03-03 11:21:43 -080053 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
54 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070055
Matthew Haddon74567092021-07-16 15:38:20 +010056 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000057 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010058 seed = self.random_seed + 1
59 self.rng = np.random.default_rng(seed)
60
Eric Kunzee5e26762020-10-13 16:11:07 -070061 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070062 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070063 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070064 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070065 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070066 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070067 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010068 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
69 elif dtype == DType.UINT8:
70 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070071 elif dtype == DType.INT16:
72 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010073 elif dtype == DType.UINT16:
74 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070075 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080076 return np.int32(
77 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
78 )
Eric Kunzee5e26762020-10-13 16:11:07 -070079 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080080 return np.int64(
81 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
82 )
Eric Kunzee5e26762020-10-13 16:11:07 -070083 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +010084 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070085 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -080086 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -070087
Kevin Cheng989cb052021-04-28 16:29:44 -070088 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -070089 placeholders = []
90
Kevin Cheng989cb052021-04-28 16:29:44 -070091 assert len(shape_list) == len(dtype_list)
92
93 for idx, shape in enumerate(shape_list):
94 arr = self.getRandTensor(shape, dtype_list[idx])
95 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -070096
97 return placeholders
98
Kevin Cheng989cb052021-04-28 16:29:44 -070099 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700100 consts = []
101
Kevin Cheng989cb052021-04-28 16:29:44 -0700102 assert len(shape_list) == len(dtype_list)
103
104 for idx, shape in enumerate(shape_list):
105 arr = self.getRandTensor(shape, dtype_list[idx])
106 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700107
108 return consts
109
110 def makeShape(self, rank):
111 if self.targetted_shape:
112 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800113 return np.int32(
114 self.rng.integers(
115 low=self.args.tensor_shape_range[0],
116 high=self.args.tensor_shape_range[1],
117 size=rank,
118 )
119 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700120
121 def setTargetShape(self, shape):
122 self.targetted_shape = shape
123
124 def randInt(self, low=0, high=256):
125 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
126
127 def getRandNumberDType(self, dtype):
128 if dtype == DType.FLOAT:
129 return self.rng.random()
130 elif dtype == DType.BOOL:
131 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700132 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700133 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700134 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700135 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100136 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700137 elif dtype == DType.INT16:
138 low, high = (-32768, 32768)
139 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800140 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800142 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 # Special size
144 return np.int64(self.rng.integers(low, high, size=1))[0]
145 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800146 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700147
148 return np.int32(self.rng.integers(low, high, size=1))[0]
149
150 def shapeStr(self, shape):
151
152 sStr = []
153 # Convert to strings
154 for i in shape:
155 sStr.append(str(i))
156
Kevin Cheng550ccc52021-03-03 11:21:43 -0800157 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
159 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -0700160 if isinstance(t, list):
161 assert len(t) >= 2
162 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700163 else:
Kevin Cheng989cb052021-04-28 16:29:44 -0700164 if t == DType.BOOL:
165 return "b"
166 elif t == DType.INT4:
167 return "i4"
168 elif t == DType.INT8:
169 return "i8"
170 elif t == DType.UINT8:
171 return "u8"
172 elif t == DType.INT16:
173 return "i16"
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100174 elif t == DType.UINT16:
175 return "u16"
Kevin Cheng989cb052021-04-28 16:29:44 -0700176 elif t == DType.INT32:
177 return "i32"
178 elif t == DType.INT48:
179 return "i48"
180 elif t == DType.FLOAT:
181 return "float"
182 else:
183 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184
185 def typeWidth(self, t):
Jeremy Johnson5d1a3472022-03-31 09:50:06 +0100186 """Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -0800187 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -0700188 return 4
189 elif t == DType.INT8:
190 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -0800191 elif t == DType.UINT8:
192 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -0700193 elif t == DType.INT16:
194 return 16
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100195 elif t == DType.UINT16:
196 return 16
Eric Kunzee5e26762020-10-13 16:11:07 -0700197 elif t == DType.INT32:
198 return 32
199 elif t == DType.INT48:
200 return 48
Matthew Haddonc2025212021-10-08 21:21:05 +0100201 elif t == DType.FLOAT:
202 return 32
203 elif t == DType.BOOL:
204 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700205 else:
Les Bell729b0352021-11-24 10:28:21 +0000206 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
208 # Argument generators
209 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
210 # Where the string descriptor is used to generate the test name and
211 # The build_fcn_arg_list is expanded and passed to the operator test
212 # build function
213
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100214 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
215 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
216
Matthew Haddon848efb42021-09-09 12:30:53 +0100217 # build_placeholder returns an int, ABS/other ops does not
218 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000219 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100220 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000221 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000222 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100223 return result_tens
224
225 # Ensure new output type has correct qinfo
226 if error_name == ErrorIf.WrongOutputType:
227 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000228 qinfo = [
229 TosaQuantGen.getZeroPoint(self, a.dtype),
230 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
231 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100232
233 # Invalidate Input/Output list for error if checks.
234 input_list = [a.name]
235 output_list = [result_tens.name]
236 pCount, cCount = op["operands"]
237 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000238 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
239 self, error_name, input_list, output_list
240 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100241
Les Bell729b0352021-11-24 10:28:21 +0000242 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100243 self.ser,
244 validator_fcns,
245 error_name,
246 op=op,
247 input_dtype=a.dtype,
248 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000249 qinfo=qinfo,
250 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100251 input_list=input_list,
252 output_list=output_list,
253 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000254 ):
255 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100256
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000257 attr = None
258 if op["op"] == Op.NEGATE:
259 attr = ts.TosaSerializerAttribute()
260 attr.NegateAttribute(qinfo[0], qinfo[1])
261
262 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700263 return result_tens
264
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100265 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000266 result_tens = OutputShaper.binaryBroadcastOp(
267 self.ser, self.rng, a, b, error_name
268 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100269
270 # Invalidate Input/Output list for error if checks.
271 input_list = [a.name, b.name]
272 output_list = [result_tens.name]
273 pCount, cCount = op["operands"]
274 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000275 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
276 self, error_name, input_list, output_list
277 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100278
Les Bell729b0352021-11-24 10:28:21 +0000279 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100280 self.ser,
281 validator_fcns,
282 error_name,
283 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000284 input1=a,
285 input2=b,
286 input_dtype=a.dtype,
287 output_dtype=result_tens.dtype,
288 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100289 input_list=input_list,
290 output_list=output_list,
291 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000292 ):
293 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100294
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000295 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700296 return result_tens
297
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100298 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700299 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000300 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700301 return result_tens
302
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000303 def build_arithmetic_right_shift(
304 self, op, a, b, round, validator_fcns=None, error_name=None
305 ):
306 result_tens = OutputShaper.binaryBroadcastOp(
307 self.ser, self.rng, a, b, error_name
308 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100309
310 # Invalidate Input/Output list for error if checks.
311 input_list = [a.name, b.name]
312 output_list = [result_tens.name]
313 pCount, cCount = op["operands"]
314 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000315 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
316 self, error_name, input_list, output_list
317 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100318
Les Bell729b0352021-11-24 10:28:21 +0000319 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100320 self.ser,
321 validator_fcns,
322 error_name,
323 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000324 input1=a,
325 input2=b,
326 input_dtype=a.dtype,
327 output_dtype=result_tens.dtype,
328 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100329 input_list=input_list,
330 output_list=output_list,
331 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000332 ):
333 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800334
335 attr = ts.TosaSerializerAttribute()
336 attr.ArithmeticRightShiftAttribute(round)
337
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000338 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800339 return result_tens
340
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100341 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000342 result_tens = OutputShaper.binaryBroadcastOp(
343 self.ser, self.rng, a, b, error_name
344 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700345
346 # Special for multiply:
347 # Force the result to INT32 for INT types
348 if a.dtype != DType.FLOAT:
349 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100350 if error_name == ErrorIf.WrongOutputType:
351 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
352 outputDType = self.rng.choice(all_dtypes)
353 result_tens.setDtype(outputDType)
354
355 # Invalidate Input/Output list for error if checks.
356 input_list = [a.name, b.name]
357 output_list = [result_tens.name]
358 pCount, cCount = op["operands"]
359 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000360 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
361 self, error_name, input_list, output_list
362 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100363
Les Bell729b0352021-11-24 10:28:21 +0000364 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100365 self.ser,
366 validator_fcns,
367 error_name,
368 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000369 input1=a,
370 input2=b,
371 input_dtype=a.dtype,
372 output_dtype=result_tens.dtype,
373 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100374 input_list=input_list,
375 output_list=output_list,
376 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000377 ):
378 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700379
Kevin Chengaee1fac2020-11-11 13:54:06 -0800380 attr = ts.TosaSerializerAttribute()
381 attr.MulAttribute(shift)
382
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000383 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700384 return result_tens
385
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100386 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
387 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700388
Kevin Chengfe392ce2021-10-18 21:51:55 +0000389 attr = ts.TosaSerializerAttribute()
390 attr.TableAttribute(table)
391
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100392 # Invalidate Input/Output list for error if checks.
393 input_list = [a.name]
394 output_list = [result_tens.name]
395 pCount, cCount = op["operands"]
396 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000397 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
398 self, error_name, input_list, output_list
399 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100400
Les Bell729b0352021-11-24 10:28:21 +0000401 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100402 self.ser,
403 validator_fcns,
404 error_name,
405 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000406 input_shape=a.shape,
407 input_dtype=a.dtype,
408 output_dtype=result_tens.dtype,
409 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410 input_list=input_list,
411 output_list=output_list,
412 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000413 ):
414 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100415
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000416 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700417
418 return result_tens
419
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100420 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
421 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
422
423 # Invalidate Input/Output list for error if checks.
424 input_list = [cond.name, a.name, b.name]
425 output_list = [result_tens.name]
426 pCount, cCount = op["operands"]
427 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000428 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
429 self, error_name, input_list, output_list
430 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100431
Les Bell729b0352021-11-24 10:28:21 +0000432 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100433 self.ser,
434 validator_fcns,
435 error_name,
436 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000437 input1=cond,
438 input2=a,
439 input3=b,
440 input_shape=a.shape,
441 input_dtype=a.dtype,
442 output_dtype=result_tens.dtype,
443 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100444 input_list=input_list,
445 output_list=output_list,
446 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000447 ):
448 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100449
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000450 self.ser.addOperator(
451 op["op"],
452 input_list,
453 output_list,
454 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700455 return result_tens
456
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100457 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000458 result_tens = OutputShaper.binaryComparisonOp(
459 self.ser, self.rng, a, b, error_name
460 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100461
462 # Invalidate Input/Output list for error if checks.
463 input_list = [a.name, b.name]
464 output_list = [result_tens.name]
465 pCount, cCount = op["operands"]
466 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000467 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
468 self, error_name, input_list, output_list
469 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100470
Les Bell729b0352021-11-24 10:28:21 +0000471 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100472 self.ser,
473 validator_fcns,
474 error_name,
475 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000476 input1=a,
477 input2=b,
478 input_shape=a.shape,
479 input_dtype=a.dtype,
480 output_shape=result_tens.shape,
481 output_dtype=result_tens.dtype,
482 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100483 input_list=input_list,
484 output_list=output_list,
485 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000486 ):
487 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100488
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000489 self.ser.addOperator(
490 op["op"],
491 input_list,
492 output_list,
493 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700494 return result_tens
495
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100496 def build_argmax(self, op, a, axis, validator_fcns, error_name):
497 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
498
499 # Invalidate Input/Output list for error if checks.
500 input_list = [a.name]
501 output_list = [result_tens.name]
502 pCount, cCount = op["operands"]
503 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000504 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
505 self, error_name, input_list, output_list
506 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100507
Les Bell729b0352021-11-24 10:28:21 +0000508 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100509 self.ser,
510 validator_fcns,
511 error_name,
512 op=op,
513 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000514 input_shape=a.shape,
515 input_dtype=a.dtype,
516 output_shape=result_tens.shape,
517 output_dtype=result_tens.dtype,
518 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100519 input_list=input_list,
520 output_list=output_list,
521 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000522 ):
523 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700524
525 attr = ts.TosaSerializerAttribute()
526 attr.AxisAttribute(axis)
527
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000528 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700529 return result_tens
530
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000531 def build_pool2d(
532 self,
533 op,
534 input,
535 stride,
536 pad,
537 kernel,
538 validator_fcns=None,
539 error_name=None,
540 qinfo=None,
541 ):
542 result_tens = OutputShaper.pool2dOp(
543 self.ser, self.rng, input, kernel, stride, pad, error_name
544 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100545
546 # Ensure new output type has correct qinfo
547 if error_name == ErrorIf.WrongInputType:
548 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000549 qinfo = [
550 TosaQuantGen.getZeroPoint(self, input.dtype),
551 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
552 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100553
554 # Invalidate Input/Output list for error if checks.
555 input_list = [input.name]
556 output_list = [result_tens.name]
557 pCount, cCount = op["operands"]
558 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000559 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
560 self, error_name, input_list, output_list
561 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100562
Les Bell729b0352021-11-24 10:28:21 +0000563 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100564 self.ser,
565 validator_fcns,
566 error_name,
567 op=op,
568 input_shape=input.shape,
569 input_dtype=input.dtype,
570 output_shape=result_tens.shape,
571 output_dtype=result_tens.dtype,
572 kernel=kernel,
573 stride=stride,
574 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000575 qinfo=qinfo,
576 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100577 input_list=input_list,
578 output_list=output_list,
579 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000580 ):
581 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700582
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000583 if qinfo is None:
584 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700585
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000586 attr = ts.TosaSerializerAttribute()
587 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
588
589 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700590 return result_tens
591
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000592 def build_conv2d(
593 self,
594 op,
595 ifm,
596 filter,
597 bias,
598 strides,
599 padding,
600 dilations,
601 validator_fcns=None,
602 error_name=None,
603 qinfo=None,
604 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800605 assert len(padding) == 4
606 result_tens = OutputShaper.conv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000607 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
608 )
609
610 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000611 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
612 DType.INT8,
613 DType.UINT8,
614 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000615 qinfo = [
616 TosaQuantGen.getZeroPoint(self, ifm.dtype),
617 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
618 ]
Les Bell0e027d42021-11-09 14:42:14 +0000619
620 # Invalidate Input/Output list for error_if checks.
621 input_list = [ifm.name, filter.name, bias.name]
622 output_list = [result_tens.name]
623 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000624 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
625 self, error_name, input_list, output_list
626 )
Les Bell0e027d42021-11-09 14:42:14 +0000627
Les Bell729b0352021-11-24 10:28:21 +0000628 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000629 self.ser,
630 validator_fcns,
631 error_name,
632 op=op,
633 input_dtype=ifm.dtype,
634 weight_dtype=filter.dtype,
635 output_dtype=result_tens.dtype,
636 qinfo=qinfo,
637 input_list=input_list,
638 num_operands=num_operands,
639 output_list=output_list,
640 pad=padding,
641 stride=strides,
642 dilation=dilations,
643 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100644 weight_shape=filter.shape,
645 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000646 ):
647 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700648
649 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000650 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700651
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000652 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700653 return result_tens
654
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000655 def build_conv3d(
656 self,
657 op,
658 ifm,
659 filter,
660 bias,
661 strides,
662 padding,
663 dilations,
664 validator_fcns=None,
665 error_name=None,
666 qinfo=None,
667 ):
Kevin Cheng1533b852021-09-01 12:51:58 -0700668 assert len(padding) == 6
669 result_tens = OutputShaper.conv3dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000670 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
671 )
672
673 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000674 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
675 DType.INT8,
676 DType.UINT8,
677 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000678 qinfo = [
679 TosaQuantGen.getZeroPoint(self, ifm.dtype),
680 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
681 ]
Les Bell0e027d42021-11-09 14:42:14 +0000682
683 # Invalidate Input/Output list for error_if checks.
684 input_list = [ifm.name, filter.name, bias.name]
685 output_list = [result_tens.name]
686 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000687 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
688 self, error_name, input_list, output_list
689 )
Les Bell0e027d42021-11-09 14:42:14 +0000690
Les Bell729b0352021-11-24 10:28:21 +0000691 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000692 self.ser,
693 validator_fcns,
694 error_name,
695 op=op,
696 input_dtype=ifm.dtype,
697 weight_dtype=filter.dtype,
698 output_dtype=result_tens.dtype,
699 qinfo=qinfo,
700 input_list=input_list,
701 num_operands=num_operands,
702 output_list=output_list,
703 pad=padding,
704 stride=strides,
705 dilation=dilations,
706 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100707 weight_shape=filter.shape,
708 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000709 ):
710 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700711
712 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000713 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Kevin Cheng1533b852021-09-01 12:51:58 -0700714
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000715 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700716 return result_tens
717
Kevin Cheng550ccc52021-03-03 11:21:43 -0800718 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000719 self,
720 op,
721 ifm,
722 filter,
723 bias,
724 stride,
TatWai Chong24594f52022-06-08 00:48:04 -0700725 out_pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000726 output_shape,
727 validator_fcns=None,
728 error_name=None,
729 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800730 ):
TatWai Chong24594f52022-06-08 00:48:04 -0700731 assert len(out_pad) == 4
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000732 result_tens = OutputShaper.transposeConv2DOp(
733 self.ser, self.rng, ifm, output_shape, error_name
734 )
Les Bell0e027d42021-11-09 14:42:14 +0000735
736 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000737 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
738 DType.INT8,
739 DType.UINT8,
740 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000741 qinfo = [
742 TosaQuantGen.getZeroPoint(self, ifm.dtype),
743 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
744 ]
Les Bell0e027d42021-11-09 14:42:14 +0000745
746 # Invalidate Input/Output list for error_if checks.
747 input_list = [ifm.name, filter.name, bias.name]
748 output_list = [result_tens.name]
749 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000750 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
751 self, error_name, input_list, output_list
752 )
Les Bell0e027d42021-11-09 14:42:14 +0000753
Les Bell729b0352021-11-24 10:28:21 +0000754 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000755 self.ser,
756 validator_fcns,
757 error_name,
758 op=op,
759 input_dtype=ifm.dtype,
760 weight_dtype=filter.dtype,
761 output_dtype=result_tens.dtype,
762 qinfo=qinfo,
763 input_list=input_list,
764 num_operands=num_operands,
765 output_list=output_list,
TatWai Chong24594f52022-06-08 00:48:04 -0700766 pad=out_pad,
Les Bell0e027d42021-11-09 14:42:14 +0000767 stride=stride,
Les Bell0e027d42021-11-09 14:42:14 +0000768 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100769 weight_shape=filter.shape,
770 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000771 ):
772 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700773
774 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000775 attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700776
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000777 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700778 return result_tens
779
Kevin Cheng550ccc52021-03-03 11:21:43 -0800780 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000781 self,
782 op,
783 ifm,
784 filter,
785 bias,
786 strides,
787 padding,
788 dilations,
789 validator_fcns=None,
790 error_name=None,
791 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800792 ):
793 result_tens = OutputShaper.depthwiseConv2dOp(
Les Bell0e027d42021-11-09 14:42:14 +0000794 self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
795 )
796
797 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000798 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
799 DType.INT8,
800 DType.UINT8,
801 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000802 qinfo = [
803 TosaQuantGen.getZeroPoint(self, ifm.dtype),
804 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
805 ]
Les Bell0e027d42021-11-09 14:42:14 +0000806
807 # Invalidate Input/Output list for error_if checks.
808 input_list = [ifm.name, filter.name, bias.name]
809 output_list = [result_tens.name]
810 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000811 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
812 self, error_name, input_list, output_list
813 )
Les Bell0e027d42021-11-09 14:42:14 +0000814
Les Bell729b0352021-11-24 10:28:21 +0000815 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000816 self.ser,
817 validator_fcns,
818 error_name,
819 op=op,
820 input_dtype=ifm.dtype,
821 weight_dtype=filter.dtype,
822 output_dtype=result_tens.dtype,
823 qinfo=qinfo,
824 input_list=input_list,
825 num_operands=num_operands,
826 output_list=output_list,
827 pad=padding,
828 stride=strides,
829 dilation=dilations,
830 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100831 weight_shape=filter.shape,
832 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000833 ):
834 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700835
836 attr = ts.TosaSerializerAttribute()
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000837 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700838
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000839 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700840 return result_tens
841
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000842 def build_fully_connected(
843 self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
844 ):
845 result_tens = OutputShaper.fullyConnectedOp(
846 self.ser, self.rng, ifm, filter, error_name
847 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100848
849 # Invalidate Input/Output list for error if checks.
850 input_list = [ifm.name, filter.name, bias.name]
851 output_list = [result_tens.name]
852 pCount, cCount = op["operands"]
853 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000854 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
855 self, error_name, input_list, output_list
856 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100857
Les Bell729b0352021-11-24 10:28:21 +0000858 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100859 self.ser,
860 validator_fcns,
861 error_name,
862 op=op,
863 input_shape=ifm.shape,
864 input_dtype=ifm.dtype,
865 weight_dtype=filter.dtype,
866 output_shape=result_tens.shape,
867 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000868 qinfo=qinfo,
869 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100870 input_list=input_list,
871 output_list=output_list,
872 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000873 ):
874 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700875
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000876 attr = ts.TosaSerializerAttribute()
877 attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
878
879 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700880 return result_tens
881
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100882 def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
883 result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
884
885 # Invalidate Input/Output list for error if checks.
886 input_list = [a.name, b.name]
887 output_list = [result_tens.name]
888 pCount, cCount = op["operands"]
889 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000890 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
891 self, error_name, input_list, output_list
892 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100893
Les Bell729b0352021-11-24 10:28:21 +0000894 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100895 self.ser,
896 validator_fcns,
897 error_name,
898 op=op,
899 input_shape=a.shape,
900 input_dtype=a.dtype,
901 input2_shape=b.shape,
902 input2_dtype=b.dtype,
903 output_shape=result_tens.shape,
904 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000905 qinfo=qinfo,
906 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100907 input_list=input_list,
908 output_list=output_list,
909 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000910 ):
911 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100912
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000913 attr = ts.TosaSerializerAttribute()
914 attr.MatMulAttribute(qinfo[0], qinfo[1])
915
916 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 return result_tens
918
Matthew Haddond6ce7252021-09-29 15:35:44 +0100919 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
920 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
921
922 # Invalidate Input/Output list for error if checks.
923 input_list = [a.name]
924 output_list = [result_tens.name]
925 pCount, cCount = op["operands"]
926 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000927 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
928 self, error_name, input_list, output_list
929 )
Matthew Haddond6ce7252021-09-29 15:35:44 +0100930
Les Bell729b0352021-11-24 10:28:21 +0000931 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +0100932 self.ser,
933 validator_fcns,
934 error_name,
935 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000936 axis=axis,
937 input_shape=a.shape,
938 output_shape=result_tens.shape,
939 input_dtype=a.dtype,
940 output_dtype=result_tens.dtype,
941 result_tensor=result_tens,
Matthew Haddond6ce7252021-09-29 15:35:44 +0100942 input_list=input_list,
943 output_list=output_list,
944 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000945 ):
946 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700947
948 attr = ts.TosaSerializerAttribute()
949 attr.AxisAttribute(axis)
950
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000951 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700952 return result_tens
953
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100954 def build_clamp(self, op, a, validator_fcns=None, error_name=None):
955 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700956
Jeremy Johnson18e26662021-07-22 16:15:29 +0100957 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -0700958
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100959 if error_name == ErrorIf.MaxSmallerMin:
960 # Make sure the numbers are different to invoke this error
961 while v[0] == v[1]:
962 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
963 max_val = min(v)
964 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -0700965 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100966 max_val = max(v)
967 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -0700968
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100969 # Invalidate Input/Output list for error if checks.
970 input_list = [a.name]
971 output_list = [result_tens.name]
972 pCount, cCount = op["operands"]
973 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000974 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
975 self, error_name, input_list, output_list
976 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100977
Les Bell729b0352021-11-24 10:28:21 +0000978 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100979 self.ser,
980 validator_fcns,
981 error_name,
982 op=op,
983 max_val=max_val,
984 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000985 input_shape=a.shape,
986 output_shape=result_tens.shape,
987 input_dtype=a.dtype,
988 output_dtype=result_tens.dtype,
989 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100990 input_list=input_list,
991 output_list=output_list,
992 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000993 ):
994 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100995
996 attr = ts.TosaSerializerAttribute()
997 if a.dtype == DType.FLOAT:
998 attr.ClampAttribute(0, 0, min_val, max_val)
999 else:
1000 attr.ClampAttribute(min_val, max_val, 0, 0)
1001
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001002 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001003 return result_tens
1004
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001005 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
1006 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001007 attr = ts.TosaSerializerAttribute()
1008
1009 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1010
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001011 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001012 return result_tens
1013
1014 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001015 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1016 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001017
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001018 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001019 return result_tens
1020
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001021 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
1022 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1023
1024 # Invalidate Input/Output list for error if checks.
1025 input_list = [a.name]
1026 output_list = [result_tens.name]
1027 pCount, cCount = op["operands"]
1028 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001029 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1030 self, error_name, input_list, output_list
1031 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001032
Les Bell729b0352021-11-24 10:28:21 +00001033 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001034 self.ser,
1035 validator_fcns,
1036 error_name,
1037 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001038 input_shape=a.shape,
1039 output_shape=result_tens.shape,
1040 input_dtype=a.dtype,
1041 output_dtype=result_tens.dtype,
1042 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001043 input_list=input_list,
1044 output_list=output_list,
1045 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001046 ):
1047 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001048
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001049 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001050 return result_tens
1051
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001052 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
1053 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1054
1055 # Invalidate Input/Output list for error if checks.
1056 input_list = [a.name]
1057 output_list = [result_tens.name]
1058 pCount, cCount = op["operands"]
1059 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001060 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1061 self, error_name, input_list, output_list
1062 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001063
Les Bell729b0352021-11-24 10:28:21 +00001064 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001065 self.ser,
1066 validator_fcns,
1067 error_name,
1068 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001069 input_shape=a.shape,
1070 output_shape=result_tens.shape,
1071 input_dtype=a.dtype,
1072 output_dtype=result_tens.dtype,
1073 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001074 input_list=input_list,
1075 output_list=output_list,
1076 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001077 ):
1078 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001079
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001080 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001081 return result_tens
1082
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001083 def build_concat(self, op, *a, validator_fcns=None, error_name=None):
1084 if error_name != ErrorIf.WrongInputType:
1085 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01001086
1087 # To store variable length list of input tensors we need to store axis along with it
1088 axis = a[-1]
1089 a = a[:-1]
1090
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001091 result_tens = OutputShaper.concatOp(
1092 self.ser, self.rng, axis, *a, error_name=error_name
1093 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001094
Matthew Haddon818ab902021-07-27 09:12:49 +01001095 input_tensor_names = []
1096 for tensor in a:
1097 input_tensor_names.append(tensor.name)
1098
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001099 # Invalidate Input/Output list for error if checks.
1100 input_list = input_tensor_names
1101 output_list = [result_tens.name]
1102 pCount, cCount = op["operands"]
1103 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001104 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1105 self, error_name, input_list, output_list
1106 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001107
Les Bell729b0352021-11-24 10:28:21 +00001108 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001109 self.ser,
1110 validator_fcns,
1111 error_name,
1112 op=op,
1113 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001114 input_shape=a[0].shape,
1115 output_shape=result_tens.shape,
1116 input_dtype=a[0].dtype,
1117 output_dtype=result_tens.dtype,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001118 inputs=a,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001119 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001120 input_list=input_list,
1121 output_list=output_list,
1122 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001123 ):
1124 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001125
1126 attr = ts.TosaSerializerAttribute()
1127 attr.AxisAttribute(axis)
1128
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001129 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddon848efb42021-09-09 12:30:53 +01001130 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001131
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001132 def build_pad(
1133 self,
1134 op,
1135 a,
1136 padding,
1137 pad_const_int,
1138 pad_const_float,
1139 validator_fcns=None,
1140 error_name=None,
1141 qinfo=None,
1142 ):
Matthew Haddone807aae2021-10-11 18:12:58 +01001143 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001144
Kevin Chengfe392ce2021-10-18 21:51:55 +00001145 attr = ts.TosaSerializerAttribute()
1146 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07001147
Matthew Haddone807aae2021-10-11 18:12:58 +01001148 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001149 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001150 output_list = [result_tens.name]
1151 pCount, cCount = op["operands"]
1152 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001153 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1154 self, error_name, input_list, output_list
1155 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001156
Les Bell729b0352021-11-24 10:28:21 +00001157 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001158 self.ser,
1159 validator_fcns,
1160 error_name,
1161 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001162 input_shape=a.shape,
1163 output_shape=result_tens.shape,
1164 input_dtype=a.dtype,
1165 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001166 pad=padding,
1167 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001168 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001169 input_list=input_list,
1170 output_list=output_list,
1171 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001172 ):
1173 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001174
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001175 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddone86fd342021-09-07 16:12:21 +01001176 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001177
Matthew Haddone807aae2021-10-11 18:12:58 +01001178 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001179 result_tens = OutputShaper.reshapeOp(
1180 self.ser, self.rng, a, newShape, error_name
1181 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001182
1183 # Invalidate Input/Output list for error if checks.
1184 input_list = [a.name]
1185 output_list = [result_tens.name]
1186 pCount, cCount = op["operands"]
1187 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001188 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1189 self, error_name, input_list, output_list
1190 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001191
Les Bell729b0352021-11-24 10:28:21 +00001192 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001193 self.ser,
1194 validator_fcns,
1195 error_name,
1196 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001197 input_shape=a.shape,
1198 output_shape=result_tens.shape,
1199 input_dtype=a.dtype,
1200 output_dtype=result_tens.dtype,
1201 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001202 input_list=input_list,
1203 output_list=output_list,
1204 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001205 ):
1206 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001207
1208 attr = ts.TosaSerializerAttribute()
1209 attr.ReshapeAttribute(newShape)
1210
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001211 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001212 return result_tens
1213
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001214 def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
1215 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1216
1217 # Invalidate Input/Output list for error if checks.
1218 input_list = [a.name]
1219 output_list = [result_tens.name]
1220 pCount, cCount = op["operands"]
1221 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001222 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1223 self, error_name, input_list, output_list
1224 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001225
Les Bell729b0352021-11-24 10:28:21 +00001226 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001227 self.ser,
1228 validator_fcns,
1229 error_name,
1230 op=op,
1231 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001232 input_shape=a.shape,
1233 output_shape=result_tens.shape,
1234 input_dtype=a.dtype,
1235 output_dtype=result_tens.dtype,
1236 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001237 input_list=input_list,
1238 output_list=output_list,
1239 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001240 ):
1241 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001242
1243 attr = ts.TosaSerializerAttribute()
1244 attr.AxisAttribute(axis)
1245
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001246 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001247 return result_tens
1248
Matthew Haddone807aae2021-10-11 18:12:58 +01001249 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
1250 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001251
Kevin Chengfe392ce2021-10-18 21:51:55 +00001252 attr = ts.TosaSerializerAttribute()
1253 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07001254
Matthew Haddone807aae2021-10-11 18:12:58 +01001255 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001256 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001257 output_list = [result_tens.name]
1258 pCount, cCount = op["operands"]
1259 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001260 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1261 self, error_name, input_list, output_list
1262 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001263
Les Bell729b0352021-11-24 10:28:21 +00001264 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001265 self.ser,
1266 validator_fcns,
1267 error_name,
1268 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001269 input_shape=a.shape,
1270 output_shape=result_tens.shape,
Matthew Haddone807aae2021-10-11 18:12:58 +01001271 perms=perms,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001272 input_dtype=a.dtype,
1273 output_dtype=result_tens.dtype,
1274 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001275 input_list=input_list,
1276 output_list=output_list,
1277 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001278 ):
1279 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001280
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001281 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001282 return result_tens
1283
Matthew Haddone807aae2021-10-11 18:12:58 +01001284 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001285 result_tens = OutputShaper.sliceOp(
1286 self.ser, self.rng, a, start, size, error_name
1287 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001288
1289 # Invalidate Input/Output list for error if checks.
1290 input_list = [a.name]
1291 output_list = [result_tens.name]
1292 pCount, cCount = op["operands"]
1293 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001294 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1295 self, error_name, input_list, output_list
1296 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001297
Les Bell729b0352021-11-24 10:28:21 +00001298 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001299 self.ser,
1300 validator_fcns,
1301 error_name,
1302 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001303 input_shape=a.shape,
1304 output_shape=result_tens.shape,
1305 input_dtype=a.dtype,
1306 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001307 start=start,
1308 size=size,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001309 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001310 input_list=input_list,
1311 output_list=output_list,
1312 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001313 ):
1314 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001315
1316 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01001317 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07001318
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001319 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001320 return result_tens
1321
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001322 def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
1323 result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
1324
1325 # Invalidate Input/Output list for error if checks.
1326 input_list = [a.name]
1327 output_list = [result_tens.name]
1328 pCount, cCount = op["operands"]
1329 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001330 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1331 self, error_name, input_list, output_list
1332 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001333
Les Bell729b0352021-11-24 10:28:21 +00001334 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001335 self.ser,
1336 validator_fcns,
1337 error_name,
1338 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001339 input_shape=a.shape,
1340 output_shape=result_tens.shape,
1341 input_dtype=a.dtype,
1342 output_dtype=result_tens.dtype,
1343 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001344 input_list=input_list,
1345 output_list=output_list,
1346 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001347 ):
1348 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001349
1350 attr = ts.TosaSerializerAttribute()
1351 attr.TileAttribute(multiples)
1352
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001353 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001354 return result_tens
1355
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001356 def build_gather(self, op, values, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001357
1358 # Create a new indicies tensor
1359 # here with data that doesn't exceed the dimensions of the values tensor
1360
Kevin Cheng550ccc52021-03-03 11:21:43 -08001361 K = values.shape[1] # K
1362 W = self.randInt(
1363 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1364 ) # W
1365 indicies_arr = np.int32(
1366 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1367 ) # (N, W)
1368 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001369
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001370 result_tens = OutputShaper.gatherOp(
1371 self.ser, self.rng, values, indicies, error_name
1372 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001373
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001374 # Invalidate Input/Output list for error if checks.
1375 input_list = [values.name, indicies.name]
1376 output_list = [result_tens.name]
1377 pCount, cCount = op["operands"]
1378 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001379 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1380 self, error_name, input_list, output_list
1381 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001382
Les Bell729b0352021-11-24 10:28:21 +00001383 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001384 self.ser,
1385 validator_fcns,
1386 error_name,
1387 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001388 input_shape=values.shape,
1389 output_shape=result_tens.shape,
1390 input_dtype=values.dtype,
1391 output_dtype=result_tens.dtype,
1392 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001393 input_list=input_list,
1394 output_list=output_list,
1395 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001396 ):
1397 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001398
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001399 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001400
1401 return result_tens
1402
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add the operator described by ``op`` using freshly generated indices.

    A random int32 index tensor is created and registered as a constant, the
    result tensor is obtained from ``OutputShaper.scatterOp`` and the operator
    is only serialized when the ERROR_IF validation accepts the test.

    Returns the result tensor, or None when validation rejects the test.
    """
    # Every index is drawn from [0, values_in.shape[1]) so that it never
    # exceeds the dimensions of the values_in tensor.
    index_limit = values_in.shape[1]
    index_count = input.shape[1]
    index_data = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values_in.shape[0], index_count]
        )
    )
    index_tens = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, index_tens, input, error_name
    )

    # The input/output name lists may be invalidated for ERROR_IF checks.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, index_tens.name, input.name],
        [result_tens.name],
    )

    validator_args = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
1447
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Add the operator described by ``op`` together with a resize attribute.

    The result tensor comes from ``OutputShaper.resizeOp``. The operator and
    its attribute are only serialized when the ERROR_IF validation accepts
    the supplied parameters.

    Returns the result tensor, or None when validation rejects the test.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    # The input/output name lists may be invalidated for ERROR_IF checks.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=in_names,
        output_list=out_names,
        result_tensor=result_tens,
        num_operands=placeholder_count + const_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return result_tens
1519
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Add a two-input, two-output operator whose outputs mirror its inputs.

    One result tensor is created per input via ``OutputShaper.unaryOp`` and
    both are listed as outputs of the serialized operator.

    ``validator_fcns`` is accepted for signature compatibility with the other
    build functions but is not used here.

    Returns the result tensor created for ``val`` (the tensor created for
    ``val2`` is serialized but not returned).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # The other build functions receive the operator definition dictionary
    # and serialize op["op"]; do the same here instead of handing the whole
    # dictionary to the serializer. A bare opcode is still accepted.
    opcode = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1527
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted for signature
    compatibility with the other build functions but are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001531
1532 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Add the operator described by ``op`` converting ``val`` to ``out_dtype``.

    The result tensor comes from ``OutputShaper.typeConversionOp``. The
    operator is only serialized when the ERROR_IF validation accepts the test.

    Returns the result tensor, or None when validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # The input/output name lists may be invalidated for ERROR_IF checks.
    placeholder_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    validator_args = dict(
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
1565
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Add the operator described by ``op`` with random rescale parameters.

    Steps performed:
      * choose input/output zero points from the tensor dtypes, or non-zero
        ones when a zero-point ERROR_IF test is requested;
      * draw a random scale per channel (a single one unless ``per_channel``)
        and convert it to multiplier/shift pairs;
      * for valid ``scale32`` tests, clamp the stored placeholder data so that
        every zero-point adjusted value lies inside the range allowed by the
        computed shift, rewriting the placeholder file if anything changed;
      * validate the ERROR_IF conditions and serialize the operator with its
        rescale attribute.

    Returns the result tensor, or None when validation rejects the test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Whenever a (possibly) non-zero zero point is chosen the type width is
    # widened by one bit before the scale is derived from it.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr holds int32 values. Do the bit shifts explicitly in int64
        # so the limits cannot overflow when the shift exceeds 31, whatever
        # scalar promotion rules or default integer width NumPy uses.
        shift_less_one = np.int64(shift_arr[i]) - 1
        min_shift_value_arr[i] = np.int64(-1) << shift_less_one
        max_shift_value_arr[i] = (np.int64(1) << shift_less_one) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        # Clamp in the zero-point adjusted domain, then shift back and cast
        # to the dtype the data was stored with.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1720
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Add the operator described by ``op`` with constant-only then/else blocks.

    Only the shape of ``then_tens`` is used from the supplied tensors; each
    block is filled with one random INT32 constant that is registered as the
    block output. For the output-list mismatch ERROR_IF tests the affected
    block gets a constant with a deliberately different shape.

    Returns the result tensor, or None when validation rejects the test.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    out_shape = then_tens.shape

    # Build a deliberately wrong shape (and matching data) for error_if tests
    output_mismatch_errors = [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]
    if error_name in output_mismatch_errors:
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    # Data for the two correctly shaped block constants
    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor takes the shape shared by both block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # The attribute carries the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Serialize the operator, then populate the two blocks in order
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, mismatch_error, block_data in block_specs:
        self.ser.startBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(
                incorrect_shape, DType.INT32, incorrect_arr
            )
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_data)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    if not is_valid:
        return None

    return result_tens
1788
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Add the operator described by ``op`` with a binary op in each block.

    Both blocks take ``a`` and ``b`` as inputs. For FLOAT/INT32 data the then
    block adds and the else block subtracts; for INT8/INT16 data they perform
    a logical right and a logical left shift respectively. For the input/output
    list mismatch ERROR_IF tests the affected block is given an input tensor,
    or an output shape, that deliberately differs from ``a``.

    Returns the result tensor, or None when validation rejects the test.
    Raises AssertionError when ``a.dtype`` has no block operators defined.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, small dimensions are not
        # special-cased here, so a perturbed dimension can end up <= 0 -
        # confirm that this is intended.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        # Raised explicitly so the check is not stripped when running with -O
        raise AssertionError(f"No tests for DType: {a.dtype}")

    block_specs = (
        (
            then_block,
            then_op,
            ErrorIf.CondIfInputListThenGraphMismatch,
            ErrorIf.CondIfOutputListThenGraphMismatch,
        ),
        (
            else_block,
            else_op,
            ErrorIf.CondIfInputListElseGraphMismatch,
            ErrorIf.CondIfOutputListElseGraphMismatch,
        ),
    )
    # The loop variable is deliberately not called "op": rebinding that name
    # would make the validator below receive the else-block opcode instead of
    # the operator definition passed in by the caller.
    for block, block_op, input_error, output_error in block_specs:
        self.ser.startBasicBlock(block)
        first_input = a
        block_out_shape = a.shape
        if error_name == input_error:
            first_input = incorrect_block_input
        elif error_name == output_error:
            block_out_shape = incorrect_block_input.shape
        self.ser.addInputTensor(first_input)
        self.ser.addInputTensor(b)
        tens = self.ser.addOutput(block_out_shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1870
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Add the operator described by ``op`` with a condition and a body block.

    The loop carries three tensors: an iteration counter initialised from
    ``iter_val``, the tensor ``a`` and a zero-initialised accumulator. The
    condition block compares the counter with zero using GREATER; the body
    block adds ``a`` to the accumulator and subtracts one from the counter.
    For the ERROR_IF tests the affected tensor lists are given deliberately
    mismatching shapes or a non-boolean condition output.

    Returns the accumulator output tensor of the loop, or None when
    validation rejects the test.
    """
    shape_deltas = [-3, -2, 2, 3]

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, starting at zero
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for idx, dim in enumerate(incorrect_acc.shape):
            incorrect_acc.shape[idx] = dim + self.rng.choice(shape_deltas)
        acc_out_shape = incorrect_acc.shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_mismatch_errors = [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]
    if error_name in graph_mismatch_errors:
        incorrect_iter = deepcopy(iter_tens)
        for idx, dim in enumerate(incorrect_iter.shape):
            incorrect_iter.shape[idx] = dim + self.rng.choice(shape_deltas)
        if len(incorrect_iter.shape) == 0:
            # No dimension to perturb, so add one instead
            incorrect_iter.shape.append(self.rng.choice(shape_deltas))

        incorrect_acc = deepcopy(acc)
        for idx, dim in enumerate(incorrect_acc.shape):
            incorrect_acc.shape[idx] = dim + self.rng.choice(shape_deltas)

    # COND block (input: iter, output: cond_tens )
    self.ser.startBasicBlock(cond_block)
    cond_inputs_mismatch = error_name == ErrorIf.InputListCondGraphMismatch
    self.ser.addInputTensor(incorrect_iter if cond_inputs_mismatch else iter_tens)
    self.ser.addInputTensor(a)
    self.ser.addInputTensor(incorrect_acc if cond_inputs_mismatch else acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_dtype = self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
        cond_tens = self.ser.addOutput([], cond_dtype)
    else:
        cond_tens = self.ser.addOutput([], DType.BOOL)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    body_inputs_mismatch = error_name == ErrorIf.InputListBodyGraphInputMismatch
    self.ser.addInputTensor(incorrect_iter if body_inputs_mismatch else iter_tens)
    self.ser.addInputTensor(a)
    self.ser.addInputTensor(incorrect_acc if body_inputs_mismatch else acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_source, acc_source = incorrect_iter, incorrect_acc
    else:
        iter_source, acc_source = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_source.shape, iter_source.dtype)
    acc_body_out = self.ser.addIntermediate(acc_source.shape, acc_source.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    if not is_valid:
        return None

    return acc_out
1982
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator allows.

    Returns a dictionary with the keys "shapeFilter", "rankFilter" and
    "dtypeFilter". For "negative" tests the non-None entries of the
    validator's "param_reqs" override the computed filters and, when no shape
    is required, at most the first two shapes are kept. None is returned for
    a "negative" test without a validator and for any other ``testType``.
    """
    # Default testing ranks, 1-4 inclusive, keep test sizes reasonably small.
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    # Ranks: intersect the request with the operator's supported range.
    min_rank, max_rank = op["rank"]
    op_ranks = range(min_rank, max_rank + 1)
    if rankFilter is not None:
        ranks = [rank for rank in rankFilter if min_rank <= rank <= max_rank]
    elif shapes[0] is None:
        # Nothing requested: use whichever of the default range and the
        # operator range is the smaller range of ranks.
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    # Dtypes: keep the operator dtypes that were requested; a list-valued
    # dtype entry is matched on its first element.
    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType == "negative":
        if validator is None:
            return None
        param_reqs = validator(check=False, op=op)["param_reqs"]

        # Validator requirements take precedence over the computed filters.
        required_rank = param_reqs["rank"]
        required_dtype = param_reqs["dtype"]
        required_shape = param_reqs["shape"]
        return {
            # Without a required shape, reduce the number of shapes to keep
            # test numbers small.
            "shapeFilter": shapes[:2] if required_shape is None else required_shape,
            "rankFilter": ranks if required_rank is None else required_rank,
            "dtypeFilter": dtypes if required_dtype is None else required_dtype,
        }

    return None
2063
def genOpTestList(
    self,
    opName,
    shapeFilter=[None],
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of test descriptions for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: shapes handed to ``create_filter_lists``; a ``None``
            entry is passed through to ``setTargetShape`` unchanged.
            The default list is only read by this method, never mutated.
        rankFilter: rank filter handed to ``create_filter_lists``.
        dtypeFilter: dtype filter handed to ``create_filter_lists``.
        testType: ``"positive"`` or ``"negative"``; selects the test name
            format and whether the op's ``error_if_validators`` are used.

    Returns:
        A list of tuples
        ``(opName, testStr, dtype, error_name, shapeList, args)``.
        An empty list is returned as soon as ``create_filter_lists``
        yields ``None`` for any validator.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator so that every op starts
    # from the same seed
    self.rng = np.random.default_rng(self.random_seed)

    build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # A single pass with no validator (and therefore no error name)
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # NOTE(review): this also drops any tests already collected for
            # earlier validators - confirm that this is intended.
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consists of tuples of the (str, [])
                    # string representation and the build function
                    # argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement.  A test is dropped when ANY validator flags
        # it; previously the flag was reset for every validator, so only
        # the last validator in the tuple had an effect.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2170
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build a single test for ``opName`` and serialize it if it is valid.

    Looks the operator up in ``self.TOSA_OP_LIST``, creates a serializer,
    generates the optional quantization info and the tensor operands,
    then calls the operator's build function.  When the build function
    returns a truthy result the test is serialized with
    ``self.serialize("test")``; otherwise a message is printed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        TypeError: re-raised from the build call after printing the build
            function, tensors and arguments.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
    # Both of these keys are optional in an op description
    error_if_validators = op.get("error_if_validators")
    qgen = op.get("qgen")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # A single dtype is replicated: once per shape for CONCAT, otherwise
    # once per operand
    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is generated before the tensor values, keeping the
    # order in which the generators are invoked
    qinfo = None
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            # No validators: qinfo (if any) is the last positional argument
            trailing = () if qinfo is None else (qinfo,)
            resultName = build_fcn(self, op, *tens, *testArgs, *trailing)
        else:
            # With validators: everything extra is passed by keyword
            extra = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                extra["qinfo"] = qinfo
            resultName = build_fcn(self, op, *tens, *testArgs, **extra)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002261
def createDynamicOpLists(self):
    """Expand the ``*_TEMPLATE`` operators into one op per kernel size.

    For every kernel a shallow copy of the matching template entry is
    stored in ``self.TOSA_OP_LIST`` under ``<op>_<k0>x<k1>[x<k2>]`` with
    its ``"filter"`` set to the kernel and ``"template"`` set to False.
    Afterwards every entry whose ``"template"`` value is truthy is deleted.
    """
    op_table = self.TOSA_OP_LIST

    def instantiate(prefix, kernel):
        # Copy the template and register it under its kernel-specific name
        entry = op_table[prefix + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        suffix = "x".join(str(dim) for dim in kernel)
        op_table["{}_{}".format(prefix, suffix)] = entry

    # Dynamically create op lists for convolutions with a list of kernel sizes
    kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            instantiate(prefix, kernel)

    kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    for kernel in kernels_3d:
        instantiate("conv3d", kernel)

    # Delete any templates after having created any dynamic ops.
    # The names are collected first because keys must not be deleted from
    # a dictionary while iterating over it.
    template_names = [
        name for name, entry in op_table.items() if entry.get("template")
    ]
    for name in template_names:
        del op_table[name]
2308
def initOpListDefaults(self):
    """Lint every entry of ``self.TOSA_OP_LIST`` and fill in defaults.

    Each op must provide ``"operands"`` (a 2-element sequence),
    ``"build_fcn"`` (a 4-element sequence), ``"types"`` and ``"op"``.
    A missing ``"rank"`` is set to ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: naming the op and the first required field that is
            missing or malformed.
    """
    # (field, required number of elements or None, error message template)
    required_fields = (
        ("operands", 2, "Op {} is missing a valid operand tuple in TOSA_OP_LIST"),
        ("build_fcn", 4, "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST"),
        ("types", None, "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", None, "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields
        for field, arity, message in required_fields:
            try:
                value = entry[field]
                if arity is not None and len(tuple(value)) != arity:
                    raise ValueError(field)
            except (KeyError, ValueError, TypeError):
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002350
# Tensor operator list
#  'op': op name
#  'operands': tuple of (placeholder, const) operands
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE,
#    i.e. (1, TOSA_TENSOR_MAX_RANK)
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function or None)
#  'types': array of datatypes to be tested
#  'qgen': optional, function used to generate the quantization info
#  'error_if_validators': optional, validators used for negative tests
#  'invalid_test_validators': optional, validators used to remove
#    invalid tests from the positive test list
#  'template': optional, True for entries expanded (and then deleted) by
#    createDynamicOpLists
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Each entry is either a list of three dtypes or a single dtype.
# NOTE(review): the three dtypes are presumably (input, weights,
# accumulator) - confirm against the tensor generators.
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

# Rank range given to ops that do not specify their own 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002378
2379 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002380 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002381 "argmax": {
2382 "op": Op.ARGMAX,
2383 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002384 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002385 "build_fcn": (
2386 build_argmax,
2387 TosaTensorGen.tgBasic,
2388 TosaTensorValuesGen.tvgDefault,
2389 TosaArgGen.agAxis,
2390 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002391 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002392 "error_if_validators": (
2393 TosaErrorValidator.evAxisSmallerZero,
2394 TosaErrorValidator.evAxisLargerRank,
2395 TosaErrorValidator.evArgmaxOutputRankMismatch,
2396 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2397 TosaErrorValidator.evWrongRank,
2398 TosaErrorValidator.evWrongInputType,
2399 TosaErrorValidator.evWrongOutputType,
2400 TosaErrorValidator.evWrongInputList,
2401 TosaErrorValidator.evWrongOutputList,
2402 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002404 "avg_pool2d": {
2405 "op": Op.AVG_POOL2D,
2406 "operands": (1, 0),
2407 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002408 "build_fcn": (
2409 build_pool2d,
2410 TosaTensorGen.tgNHWC,
2411 TosaTensorValuesGen.tvgDefault,
2412 TosaArgGen.agPooling,
2413 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002414 "qgen": TosaQuantGen.qgUnary,
2415 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002416 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002417 "error_if_validators": (
2418 TosaErrorValidator.evKernelSmallerOne,
2419 TosaErrorValidator.evStrideSmallerOne,
2420 TosaErrorValidator.evPadSmallerZero,
2421 TosaErrorValidator.evWrongRank,
2422 TosaErrorValidator.evWrongInputType,
2423 TosaErrorValidator.evWrongOutputType,
2424 TosaErrorValidator.evWrongInputList,
2425 TosaErrorValidator.evWrongOutputList,
2426 TosaErrorValidator.evInputZeroPointNotZero,
2427 TosaErrorValidator.evOutputZeroPointNotZero,
2428 TosaErrorValidator.evPadLargerEqualKernel,
2429 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002430 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002431 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002432 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002433 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002434 "conv2d_TEMPLATE": {
2435 "op": Op.CONV2D,
2436 "operands": (1, 2),
2437 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002438 "build_fcn": (
2439 build_conv2d,
2440 TosaTensorGen.tgConv2D,
2441 TosaTensorValuesGen.tvgDefault,
2442 TosaArgGen.agConv,
2443 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002444 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002445 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002446 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2447 "error_if_validators": (
2448 TosaErrorValidator.evWrongInputType,
2449 TosaErrorValidator.evWrongOutputType,
2450 TosaErrorValidator.evWrongInputList,
2451 TosaErrorValidator.evWrongOutputList,
2452 TosaErrorValidator.evInputZeroPointNotZero,
2453 TosaErrorValidator.evWeightZeroPointNotZero,
2454 TosaErrorValidator.evPadSmallerZero,
2455 TosaErrorValidator.evStrideSmallerOne,
2456 TosaErrorValidator.evDilationSmallerOne,
2457 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002458 TosaErrorValidator.evConvOutputShapeMismatch,
2459 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002460 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002461 "template": True,
2462 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002463 # Templated operator. Filled in by createDynamicOpLists
2464 "conv3d_TEMPLATE": {
2465 "op": Op.CONV3D,
2466 "operands": (1, 2),
2467 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002468 "build_fcn": (
2469 build_conv3d,
2470 TosaTensorGen.tgConv3D,
2471 TosaTensorValuesGen.tvgDefault,
2472 TosaArgGen.agConv,
2473 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002474 "qgen": TosaQuantGen.qgConv,
2475 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002476 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2477 "error_if_validators": (
2478 TosaErrorValidator.evWrongInputType,
2479 TosaErrorValidator.evWrongOutputType,
2480 TosaErrorValidator.evWrongInputList,
2481 TosaErrorValidator.evWrongOutputList,
2482 TosaErrorValidator.evInputZeroPointNotZero,
2483 TosaErrorValidator.evWeightZeroPointNotZero,
2484 TosaErrorValidator.evPadSmallerZero,
2485 TosaErrorValidator.evStrideSmallerOne,
2486 TosaErrorValidator.evDilationSmallerOne,
2487 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002488 TosaErrorValidator.evConvOutputShapeMismatch,
2489 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002490 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002491 "template": True,
2492 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002493 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002494 "depthwise_conv2d_TEMPLATE": {
2495 "op": Op.DEPTHWISE_CONV2D,
2496 "operands": (1, 2),
2497 "filter": [1, 1],
2498 "rank": (4, 4),
2499 "build_fcn": (
2500 build_depthwise_conv2d,
2501 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002502 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002503 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002504 ),
2505 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002506 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002507 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2508 "error_if_validators": (
2509 TosaErrorValidator.evWrongInputType,
2510 TosaErrorValidator.evWrongOutputType,
2511 TosaErrorValidator.evWrongInputList,
2512 TosaErrorValidator.evWrongOutputList,
2513 TosaErrorValidator.evInputZeroPointNotZero,
2514 TosaErrorValidator.evWeightZeroPointNotZero,
2515 TosaErrorValidator.evPadSmallerZero,
2516 TosaErrorValidator.evStrideSmallerOne,
2517 TosaErrorValidator.evDilationSmallerOne,
2518 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002519 TosaErrorValidator.evConvOutputShapeMismatch,
2520 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002521 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002522 "template": True,
2523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002524 "fully_connected": {
2525 "op": Op.FULLY_CONNECTED,
2526 "operands": (1, 2),
2527 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002528 "build_fcn": (
2529 build_fully_connected,
2530 TosaTensorGen.tgFullyConnected,
2531 TosaTensorValuesGen.tvgDefault,
2532 None,
2533 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002534 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002535 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002536 "error_if_validators": (
2537 TosaErrorValidator.evInputZeroPointNotZero,
2538 TosaErrorValidator.evWeightZeroPointNotZero,
2539 TosaErrorValidator.evWrongRank,
2540 TosaErrorValidator.evWrongInputType,
2541 TosaErrorValidator.evWrongOutputType,
2542 TosaErrorValidator.evWrongInputList,
2543 TosaErrorValidator.evWrongOutputList,
2544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002545 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002546 "matmul": {
2547 "op": Op.MATMUL,
2548 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002549 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002550 "build_fcn": (
2551 build_matmul,
2552 TosaTensorGen.tgMatmul,
2553 TosaTensorValuesGen.tvgDefault,
2554 None,
2555 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002556 "qgen": TosaQuantGen.qgMatmul,
2557 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002558 "error_if_validators": (
2559 TosaErrorValidator.evInputZeroPointNotZero,
2560 TosaErrorValidator.evWrongRank,
2561 TosaErrorValidator.evWrongInputType,
2562 TosaErrorValidator.evWrongOutputType,
2563 TosaErrorValidator.evWrongInputList,
2564 TosaErrorValidator.evWrongOutputList,
2565 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002566 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002567 "max_pool2d": {
2568 "op": Op.MAX_POOL2D,
2569 "operands": (1, 0),
2570 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002571 "build_fcn": (
2572 build_pool2d,
2573 TosaTensorGen.tgNHWC,
2574 TosaTensorValuesGen.tvgDefault,
2575 TosaArgGen.agPooling,
2576 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002577 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002578 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002579 "error_if_validators": (
2580 TosaErrorValidator.evKernelSmallerOne,
2581 TosaErrorValidator.evStrideSmallerOne,
2582 TosaErrorValidator.evPadSmallerZero,
2583 TosaErrorValidator.evWrongRank,
2584 TosaErrorValidator.evWrongInputType,
2585 TosaErrorValidator.evWrongOutputType,
2586 TosaErrorValidator.evWrongInputList,
2587 TosaErrorValidator.evWrongOutputList,
2588 TosaErrorValidator.evPadLargerEqualKernel,
2589 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002590 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002592 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002593 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002594 "transpose_conv2d_TEMPLATE": {
2595 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002596 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002597 "rank": (4, 4),
2598 "build_fcn": (
2599 build_transpose_conv2d,
2600 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002601 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002602 TosaArgGen.agTransposeConv2D,
2603 ),
2604 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002605 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002606 "invalid_test_validators": (
2607 TosaInvalidValidator.ivHeightWidthInvalid,
2608 TosaInvalidValidator.ivNonPositiveOutputShape,
2609 ),
2610 "error_if_validators": (
2611 TosaErrorValidator.evWrongInputType,
2612 TosaErrorValidator.evWrongOutputType,
2613 TosaErrorValidator.evWrongInputList,
2614 TosaErrorValidator.evWrongOutputList,
2615 TosaErrorValidator.evInputZeroPointNotZero,
2616 TosaErrorValidator.evWeightZeroPointNotZero,
2617 TosaErrorValidator.evPadSmallerZero,
2618 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002619 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002620 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002621 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002622 "template": True,
2623 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002624 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 "clamp": {
2626 "op": Op.CLAMP,
2627 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002628 "build_fcn": (
2629 build_clamp,
2630 TosaTensorGen.tgBasic,
2631 TosaTensorValuesGen.tvgDefault,
2632 None,
2633 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002634 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002635 "error_if_validators": (
2636 TosaErrorValidator.evMaxSmallerMin,
2637 TosaErrorValidator.evWrongInputType,
2638 TosaErrorValidator.evWrongOutputType,
2639 TosaErrorValidator.evWrongInputList,
2640 TosaErrorValidator.evWrongOutputList,
2641 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002642 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002643 "sigmoid": {
2644 "op": Op.SIGMOID,
2645 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002646 "build_fcn": (
2647 build_sigmoid,
2648 TosaTensorGen.tgBasic,
2649 TosaTensorValuesGen.tvgDefault,
2650 None,
2651 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002652 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002653 "error_if_validators": (
2654 TosaErrorValidator.evWrongInputType,
2655 TosaErrorValidator.evWrongOutputType,
2656 TosaErrorValidator.evWrongInputList,
2657 TosaErrorValidator.evWrongOutputList,
2658 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002659 },
2660 "tanh": {
2661 "op": Op.TANH,
2662 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002663 "build_fcn": (
2664 build_tanh,
2665 TosaTensorGen.tgBasic,
2666 TosaTensorValuesGen.tvgDefault,
2667 None,
2668 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002669 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002670 "error_if_validators": (
2671 TosaErrorValidator.evWrongInputType,
2672 TosaErrorValidator.evWrongOutputType,
2673 TosaErrorValidator.evWrongInputList,
2674 TosaErrorValidator.evWrongOutputList,
2675 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002676 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002677 # Elementwise Binary Operators
2678 "add": {
2679 "op": Op.ADD,
2680 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002681 "build_fcn": (
2682 build_binary_broadcast,
2683 TosaTensorGen.tgBroadcastFuzz,
2684 TosaTensorValuesGen.tvgAddSub,
2685 None,
2686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002687 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002688 "error_if_validators": (
2689 TosaErrorValidator.evRankMismatch,
2690 TosaErrorValidator.evWrongInputType,
2691 TosaErrorValidator.evWrongOutputType,
2692 TosaErrorValidator.evWrongInputList,
2693 TosaErrorValidator.evWrongOutputList,
2694 TosaErrorValidator.evDimensionMismatch,
2695 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002696 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002697 "arithmetic_right_shift": {
2698 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2699 "operands": (2, 0),
2700 "build_fcn": (
2701 build_arithmetic_right_shift,
2702 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002703 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002704 TosaArgGen.agArithmeticRightShift,
2705 ),
2706 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002707 "error_if_validators": (
2708 TosaErrorValidator.evRankMismatch,
2709 TosaErrorValidator.evWrongInputType,
2710 TosaErrorValidator.evWrongOutputType,
2711 TosaErrorValidator.evWrongInputList,
2712 TosaErrorValidator.evWrongOutputList,
2713 TosaErrorValidator.evDimensionMismatch,
2714 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002715 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002716 "bitwise_and": {
2717 "op": Op.BITWISE_AND,
2718 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002719 "build_fcn": (
2720 build_binary_broadcast,
2721 TosaTensorGen.tgBroadcastFuzz,
2722 TosaTensorValuesGen.tvgDefault,
2723 None,
2724 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002725 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002726 "error_if_validators": (
2727 TosaErrorValidator.evRankMismatch,
2728 TosaErrorValidator.evWrongInputType,
2729 TosaErrorValidator.evWrongOutputType,
2730 TosaErrorValidator.evWrongInputList,
2731 TosaErrorValidator.evWrongOutputList,
2732 TosaErrorValidator.evDimensionMismatch,
2733 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002734 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002735 "bitwise_or": {
2736 "op": Op.BITWISE_OR,
2737 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002738 "build_fcn": (
2739 build_binary_broadcast,
2740 TosaTensorGen.tgBroadcastFuzz,
2741 TosaTensorValuesGen.tvgDefault,
2742 None,
2743 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002744 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002745 "error_if_validators": (
2746 TosaErrorValidator.evRankMismatch,
2747 TosaErrorValidator.evWrongInputType,
2748 TosaErrorValidator.evWrongOutputType,
2749 TosaErrorValidator.evWrongInputList,
2750 TosaErrorValidator.evWrongOutputList,
2751 TosaErrorValidator.evDimensionMismatch,
2752 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002753 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002754 "bitwise_xor": {
2755 "op": Op.BITWISE_XOR,
2756 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002757 "build_fcn": (
2758 build_binary_broadcast,
2759 TosaTensorGen.tgBroadcastFuzz,
2760 TosaTensorValuesGen.tvgDefault,
2761 None,
2762 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002763 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002764 "error_if_validators": (
2765 TosaErrorValidator.evRankMismatch,
2766 TosaErrorValidator.evWrongInputType,
2767 TosaErrorValidator.evWrongOutputType,
2768 TosaErrorValidator.evWrongInputList,
2769 TosaErrorValidator.evWrongOutputList,
2770 TosaErrorValidator.evDimensionMismatch,
2771 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002772 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002773 "intdiv": {
2774 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002775 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002776 "build_fcn": (
2777 build_binary_broadcast,
2778 TosaTensorGen.tgBroadcastFuzz,
2779 TosaTensorValuesGen.tvgIntDiv,
2780 None,
2781 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002782 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002783 "error_if_validators": (
2784 TosaErrorValidator.evRankMismatch,
2785 TosaErrorValidator.evWrongInputType,
2786 TosaErrorValidator.evWrongOutputType,
2787 TosaErrorValidator.evWrongInputList,
2788 TosaErrorValidator.evWrongOutputList,
2789 TosaErrorValidator.evDimensionMismatch,
2790 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002791 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002792 "logical_and": {
2793 "op": Op.LOGICAL_AND,
2794 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002795 "build_fcn": (
2796 build_binary_broadcast,
2797 TosaTensorGen.tgBroadcastFuzz,
2798 TosaTensorValuesGen.tvgDefault,
2799 None,
2800 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002801 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002802 "error_if_validators": (
2803 TosaErrorValidator.evRankMismatch,
2804 TosaErrorValidator.evWrongInputType,
2805 TosaErrorValidator.evWrongOutputType,
2806 TosaErrorValidator.evWrongInputList,
2807 TosaErrorValidator.evWrongOutputList,
2808 TosaErrorValidator.evDimensionMismatch,
2809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002810 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002811 "logical_left_shift": {
2812 "op": Op.LOGICAL_LEFT_SHIFT,
2813 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002814 "build_fcn": (
2815 build_binary_broadcast,
2816 TosaTensorGen.tgBroadcastFuzz,
2817 TosaTensorValuesGen.tvgLogicalShift,
2818 None,
2819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002820 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002821 "error_if_validators": (
2822 TosaErrorValidator.evRankMismatch,
2823 TosaErrorValidator.evWrongInputType,
2824 TosaErrorValidator.evWrongOutputType,
2825 TosaErrorValidator.evWrongInputList,
2826 TosaErrorValidator.evWrongOutputList,
2827 TosaErrorValidator.evDimensionMismatch,
2828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002829 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002830 "logical_right_shift": {
2831 "op": Op.LOGICAL_RIGHT_SHIFT,
2832 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002833 "build_fcn": (
2834 build_binary_broadcast,
2835 TosaTensorGen.tgBroadcastFuzz,
2836 TosaTensorValuesGen.tvgLogicalShift,
2837 None,
2838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002839 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002840 "error_if_validators": (
2841 TosaErrorValidator.evRankMismatch,
2842 TosaErrorValidator.evWrongInputType,
2843 TosaErrorValidator.evWrongOutputType,
2844 TosaErrorValidator.evWrongInputList,
2845 TosaErrorValidator.evWrongOutputList,
2846 TosaErrorValidator.evDimensionMismatch,
2847 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002848 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002849 "logical_or": {
2850 "op": Op.LOGICAL_OR,
2851 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002852 "build_fcn": (
2853 build_binary_broadcast,
2854 TosaTensorGen.tgBroadcastFuzz,
2855 TosaTensorValuesGen.tvgDefault,
2856 None,
2857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002858 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002859 "error_if_validators": (
2860 TosaErrorValidator.evRankMismatch,
2861 TosaErrorValidator.evWrongInputType,
2862 TosaErrorValidator.evWrongOutputType,
2863 TosaErrorValidator.evWrongInputList,
2864 TosaErrorValidator.evWrongOutputList,
2865 TosaErrorValidator.evDimensionMismatch,
2866 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002867 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002868 "logical_xor": {
2869 "op": Op.LOGICAL_XOR,
2870 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002871 "build_fcn": (
2872 build_binary_broadcast,
2873 TosaTensorGen.tgBroadcastFuzz,
2874 TosaTensorValuesGen.tvgDefault,
2875 None,
2876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002877 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002878 "error_if_validators": (
2879 TosaErrorValidator.evRankMismatch,
2880 TosaErrorValidator.evWrongInputType,
2881 TosaErrorValidator.evWrongOutputType,
2882 TosaErrorValidator.evWrongInputList,
2883 TosaErrorValidator.evWrongOutputList,
2884 TosaErrorValidator.evDimensionMismatch,
2885 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002886 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002887 "maximum": {
2888 "op": Op.MAXIMUM,
2889 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002890 "build_fcn": (
2891 build_binary_broadcast,
2892 TosaTensorGen.tgBroadcastFuzz,
2893 TosaTensorValuesGen.tvgDefault,
2894 None,
2895 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002896 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002897 "error_if_validators": (
2898 TosaErrorValidator.evRankMismatch,
2899 TosaErrorValidator.evWrongInputType,
2900 TosaErrorValidator.evWrongOutputType,
2901 TosaErrorValidator.evWrongInputList,
2902 TosaErrorValidator.evWrongOutputList,
2903 TosaErrorValidator.evDimensionMismatch,
2904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002906 "minimum": {
2907 "op": Op.MINIMUM,
2908 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002909 "build_fcn": (
2910 build_binary_broadcast,
2911 TosaTensorGen.tgBroadcastFuzz,
2912 TosaTensorValuesGen.tvgDefault,
2913 None,
2914 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002915 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002916 "error_if_validators": (
2917 TosaErrorValidator.evRankMismatch,
2918 TosaErrorValidator.evWrongInputType,
2919 TosaErrorValidator.evWrongOutputType,
2920 TosaErrorValidator.evWrongInputList,
2921 TosaErrorValidator.evWrongOutputList,
2922 TosaErrorValidator.evDimensionMismatch,
2923 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002924 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002925 "mul": {
2926 "op": Op.MUL,
2927 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002928 "build_fcn": (
2929 build_mul,
2930 TosaTensorGen.tgBroadcastFuzz,
2931 TosaTensorValuesGen.tvgMul,
2932 TosaArgGen.agMul,
2933 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002934 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002935 "error_if_validators": (
2936 TosaErrorValidator.evWrongInputType,
2937 TosaErrorValidator.evWrongOutputType,
2938 TosaErrorValidator.evWrongInputList,
2939 TosaErrorValidator.evWrongOutputList,
2940 TosaErrorValidator.evRankMismatch,
2941 TosaErrorValidator.evDimensionMismatch,
2942 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002944 "pow": {
2945 "op": Op.POW,
2946 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002947 "build_fcn": (
2948 build_binary_broadcast,
2949 TosaTensorGen.tgBroadcastFuzz,
2950 TosaTensorValuesGen.tvgDefault,
2951 None,
2952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002953 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002954 "error_if_validators": (
2955 TosaErrorValidator.evRankMismatch,
2956 TosaErrorValidator.evWrongInputType,
2957 TosaErrorValidator.evWrongOutputType,
2958 TosaErrorValidator.evWrongInputList,
2959 TosaErrorValidator.evWrongOutputList,
2960 TosaErrorValidator.evDimensionMismatch,
2961 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002962 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 "sub": {
2964 "op": Op.SUB,
2965 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002966 "build_fcn": (
2967 build_binary_broadcast,
2968 TosaTensorGen.tgBroadcastFuzz,
2969 TosaTensorValuesGen.tvgAddSub,
2970 None,
2971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002973 "error_if_validators": (
2974 TosaErrorValidator.evRankMismatch,
2975 TosaErrorValidator.evWrongInputType,
2976 TosaErrorValidator.evWrongOutputType,
2977 TosaErrorValidator.evWrongInputList,
2978 TosaErrorValidator.evWrongOutputList,
2979 TosaErrorValidator.evDimensionMismatch,
2980 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002981 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002982 "table": {
2983 "op": Op.TABLE,
2984 # Use the automatic generation functions to create the input array
2985 # but create the table tensor in the build function, as it may be
2986 # a different type from the input
2987 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002988 "build_fcn": (
2989 build_table,
2990 TosaTensorGen.tgBasic,
2991 TosaTensorValuesGen.tvgDefault,
2992 TosaArgGen.agTable,
2993 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002994 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002995 "error_if_validators": (
2996 TosaErrorValidator.evWrongInputType,
2997 TosaErrorValidator.evWrongOutputType,
2998 TosaErrorValidator.evWrongInputList,
2999 TosaErrorValidator.evWrongOutputList,
3000 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003001 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003002 # Elementwise Unary operators
3003 "abs": {
3004 "op": Op.ABS,
3005 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003006 "build_fcn": (
3007 build_unary,
3008 TosaTensorGen.tgBasic,
3009 TosaTensorValuesGen.tvgDefault,
3010 None,
3011 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003013 "error_if_validators": (
3014 TosaErrorValidator.evWrongInputType,
3015 TosaErrorValidator.evWrongOutputType,
3016 TosaErrorValidator.evWrongInputList,
3017 TosaErrorValidator.evWrongOutputList,
3018 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003019 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003020 "bitwise_not": {
3021 "op": Op.BITWISE_NOT,
3022 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003023 "build_fcn": (
3024 build_unary,
3025 TosaTensorGen.tgBasic,
3026 TosaTensorValuesGen.tvgDefault,
3027 None,
3028 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003030 "error_if_validators": (
3031 TosaErrorValidator.evWrongInputType,
3032 TosaErrorValidator.evWrongOutputType,
3033 TosaErrorValidator.evWrongInputList,
3034 TosaErrorValidator.evWrongOutputList,
3035 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003036 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003037 "ceil": {
3038 "op": Op.CEIL,
3039 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003040 "build_fcn": (
3041 build_unary,
3042 TosaTensorGen.tgBasic,
3043 TosaTensorValuesGen.tvgDefault,
3044 None,
3045 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003046 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003047 "error_if_validators": (
3048 TosaErrorValidator.evWrongInputType,
3049 TosaErrorValidator.evWrongOutputType,
3050 TosaErrorValidator.evWrongInputList,
3051 TosaErrorValidator.evWrongOutputList,
3052 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003053 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003054 "clz": {
3055 "op": Op.CLZ,
3056 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003057 "build_fcn": (
3058 build_unary,
3059 TosaTensorGen.tgBasic,
3060 TosaTensorValuesGen.tvgDefault,
3061 None,
3062 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003063 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003064 "error_if_validators": (
3065 TosaErrorValidator.evWrongInputType,
3066 TosaErrorValidator.evWrongOutputType,
3067 TosaErrorValidator.evWrongInputList,
3068 TosaErrorValidator.evWrongOutputList,
3069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003071 "exp": {
3072 "op": Op.EXP,
3073 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003074 "build_fcn": (
3075 build_unary,
3076 TosaTensorGen.tgBasic,
3077 TosaTensorValuesGen.tvgDefault,
3078 None,
3079 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003081 "error_if_validators": (
3082 TosaErrorValidator.evWrongInputType,
3083 TosaErrorValidator.evWrongOutputType,
3084 TosaErrorValidator.evWrongInputList,
3085 TosaErrorValidator.evWrongOutputList,
3086 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003087 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003088 "floor": {
3089 "op": Op.FLOOR,
3090 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003091 "build_fcn": (
3092 build_unary,
3093 TosaTensorGen.tgBasic,
3094 TosaTensorValuesGen.tvgDefault,
3095 None,
3096 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003097 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003098 "error_if_validators": (
3099 TosaErrorValidator.evWrongInputType,
3100 TosaErrorValidator.evWrongOutputType,
3101 TosaErrorValidator.evWrongInputList,
3102 TosaErrorValidator.evWrongOutputList,
3103 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003104 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 "log": {
3106 "op": Op.LOG,
3107 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003108 "build_fcn": (
3109 build_unary,
3110 TosaTensorGen.tgBasic,
3111 TosaTensorValuesGen.tvgDefault,
3112 None,
3113 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003114 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003115 "error_if_validators": (
3116 TosaErrorValidator.evWrongInputType,
3117 TosaErrorValidator.evWrongOutputType,
3118 TosaErrorValidator.evWrongInputList,
3119 TosaErrorValidator.evWrongOutputList,
3120 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003121 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003122 "logical_not": {
3123 "op": Op.LOGICAL_NOT,
3124 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003125 "build_fcn": (
3126 build_unary,
3127 TosaTensorGen.tgBasic,
3128 TosaTensorValuesGen.tvgDefault,
3129 None,
3130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003132 "error_if_validators": (
3133 TosaErrorValidator.evWrongInputType,
3134 TosaErrorValidator.evWrongOutputType,
3135 TosaErrorValidator.evWrongInputList,
3136 TosaErrorValidator.evWrongOutputList,
3137 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003138 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "negate": {
3140 "op": Op.NEGATE,
3141 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003142 "build_fcn": (
3143 build_unary,
3144 TosaTensorGen.tgBasic,
3145 TosaTensorValuesGen.tvgNegate,
3146 None,
3147 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003148 "qgen": TosaQuantGen.qgUnary,
3149 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003150 "error_if_validators": (
3151 TosaErrorValidator.evInputZeroPointNotZero,
3152 TosaErrorValidator.evOutputZeroPointNotZero,
3153 TosaErrorValidator.evWrongInputType,
3154 TosaErrorValidator.evWrongOutputType,
3155 TosaErrorValidator.evWrongInputList,
3156 TosaErrorValidator.evWrongOutputList,
3157 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 "reciprocal": {
3160 "op": Op.RECIPROCAL,
3161 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 "build_fcn": (
3163 build_unary,
3164 TosaTensorGen.tgBasic,
3165 TosaTensorValuesGen.tvgDefault,
3166 None,
3167 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003169 "error_if_validators": (
3170 TosaErrorValidator.evWrongInputType,
3171 TosaErrorValidator.evWrongOutputType,
3172 TosaErrorValidator.evWrongInputList,
3173 TosaErrorValidator.evWrongOutputList,
3174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 "rsqrt": {
3177 "op": Op.RSQRT,
3178 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003179 "build_fcn": (
3180 build_unary,
3181 TosaTensorGen.tgBasic,
3182 TosaTensorValuesGen.tvgDefault,
3183 None,
3184 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003186 "error_if_validators": (
3187 TosaErrorValidator.evWrongInputType,
3188 TosaErrorValidator.evWrongOutputType,
3189 TosaErrorValidator.evWrongInputList,
3190 TosaErrorValidator.evWrongOutputList,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 # Elementwise Ternary operators
3194 "select": {
3195 "op": Op.SELECT,
3196 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003197 "build_fcn": (
3198 build_select,
3199 TosaTensorGen.tgBroadcastFuzz,
3200 TosaTensorValuesGen.tvgSelect,
3201 None,
3202 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003204 "error_if_validators": (
3205 TosaErrorValidator.evRankMismatch,
3206 TosaErrorValidator.evWrongInputType,
3207 TosaErrorValidator.evWrongOutputType,
3208 TosaErrorValidator.evWrongInputList,
3209 TosaErrorValidator.evWrongOutputList,
3210 TosaErrorValidator.evDimensionMismatch,
3211 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003213 # Comparison operators
3214 "equal": {
3215 "op": Op.EQUAL,
3216 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003217 "build_fcn": (
3218 build_comparison,
3219 TosaTensorGen.tgBroadcastFuzz,
3220 TosaTensorValuesGen.tvgEqual,
3221 None,
3222 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003224 "error_if_validators": (
3225 TosaErrorValidator.evRankMismatch,
3226 TosaErrorValidator.evWrongInputType,
3227 TosaErrorValidator.evWrongOutputType,
3228 TosaErrorValidator.evWrongInputList,
3229 TosaErrorValidator.evWrongOutputList,
3230 TosaErrorValidator.evDimensionMismatch,
3231 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003232 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003233 "greater_equal": {
3234 "op": Op.GREATER_EQUAL,
3235 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003236 "build_fcn": (
3237 build_comparison,
3238 TosaTensorGen.tgBroadcastFuzz,
3239 TosaTensorValuesGen.tvgDefault,
3240 None,
3241 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003242 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003243 "error_if_validators": (
3244 TosaErrorValidator.evRankMismatch,
3245 TosaErrorValidator.evWrongInputType,
3246 TosaErrorValidator.evWrongOutputType,
3247 TosaErrorValidator.evWrongInputList,
3248 TosaErrorValidator.evWrongOutputList,
3249 TosaErrorValidator.evDimensionMismatch,
3250 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003251 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003252 "greater": {
3253 "op": Op.GREATER,
3254 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 "build_fcn": (
3256 build_comparison,
3257 TosaTensorGen.tgBroadcastFuzz,
3258 TosaTensorValuesGen.tvgDefault,
3259 None,
3260 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003261 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003262 "error_if_validators": (
3263 TosaErrorValidator.evRankMismatch,
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 TosaErrorValidator.evDimensionMismatch,
3269 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003271 # Reduction operators
3272 "reduce_all": {
3273 "op": Op.REDUCE_ALL,
3274 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003275 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003276 "build_fcn": (
3277 build_reduce,
3278 TosaTensorGen.tgBasic,
3279 TosaTensorValuesGen.tvgDefault,
3280 TosaArgGen.agAxis,
3281 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003283 "error_if_validators": (
3284 TosaErrorValidator.evAxisLargerRank,
3285 TosaErrorValidator.evAxisSmallerZero,
3286 TosaErrorValidator.evShapeOfAxisNotOne,
3287 TosaErrorValidator.evWrongInputType,
3288 TosaErrorValidator.evWrongOutputType,
3289 TosaErrorValidator.evWrongRank,
3290 TosaErrorValidator.evWrongInputList,
3291 TosaErrorValidator.evWrongOutputList,
3292 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003293 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003294 "reduce_any": {
3295 "op": Op.REDUCE_ANY,
3296 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003297 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003298 "build_fcn": (
3299 build_reduce,
3300 TosaTensorGen.tgBasic,
3301 TosaTensorValuesGen.tvgDefault,
3302 TosaArgGen.agAxis,
3303 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003305 "error_if_validators": (
3306 TosaErrorValidator.evAxisLargerRank,
3307 TosaErrorValidator.evAxisSmallerZero,
3308 TosaErrorValidator.evShapeOfAxisNotOne,
3309 TosaErrorValidator.evWrongInputType,
3310 TosaErrorValidator.evWrongOutputType,
3311 TosaErrorValidator.evWrongRank,
3312 TosaErrorValidator.evWrongInputList,
3313 TosaErrorValidator.evWrongOutputList,
3314 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003315 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 "reduce_max": {
3317 "op": Op.REDUCE_MAX,
3318 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003319 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003320 "build_fcn": (
3321 build_reduce,
3322 TosaTensorGen.tgBasic,
3323 TosaTensorValuesGen.tvgDefault,
3324 TosaArgGen.agAxis,
3325 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003327 "error_if_validators": (
3328 TosaErrorValidator.evAxisLargerRank,
3329 TosaErrorValidator.evAxisSmallerZero,
3330 TosaErrorValidator.evShapeOfAxisNotOne,
3331 TosaErrorValidator.evWrongInputType,
3332 TosaErrorValidator.evWrongOutputType,
3333 TosaErrorValidator.evWrongRank,
3334 TosaErrorValidator.evWrongInputList,
3335 TosaErrorValidator.evWrongOutputList,
3336 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003339 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003340 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003341 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 "build_fcn": (
3343 build_reduce,
3344 TosaTensorGen.tgBasic,
3345 TosaTensorValuesGen.tvgDefault,
3346 TosaArgGen.agAxis,
3347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003349 "error_if_validators": (
3350 TosaErrorValidator.evAxisLargerRank,
3351 TosaErrorValidator.evAxisSmallerZero,
3352 TosaErrorValidator.evShapeOfAxisNotOne,
3353 TosaErrorValidator.evWrongInputType,
3354 TosaErrorValidator.evWrongOutputType,
3355 TosaErrorValidator.evWrongRank,
3356 TosaErrorValidator.evWrongInputList,
3357 TosaErrorValidator.evWrongOutputList,
3358 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 "reduce_product": {
3361 "op": Op.REDUCE_PRODUCT,
3362 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003363 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003364 "build_fcn": (
3365 build_reduce,
3366 TosaTensorGen.tgBasic,
3367 TosaTensorValuesGen.tvgDefault,
3368 TosaArgGen.agAxis,
3369 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003371 "error_if_validators": (
3372 TosaErrorValidator.evAxisLargerRank,
3373 TosaErrorValidator.evAxisSmallerZero,
3374 TosaErrorValidator.evShapeOfAxisNotOne,
3375 TosaErrorValidator.evWrongInputType,
3376 TosaErrorValidator.evWrongOutputType,
3377 TosaErrorValidator.evWrongRank,
3378 TosaErrorValidator.evWrongInputList,
3379 TosaErrorValidator.evWrongOutputList,
3380 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 "reduce_sum": {
3383 "op": Op.REDUCE_SUM,
3384 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003385 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003386 "build_fcn": (
3387 build_reduce,
3388 TosaTensorGen.tgBasic,
3389 TosaTensorValuesGen.tvgReduceSum,
3390 TosaArgGen.agAxis,
3391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003393 "error_if_validators": (
3394 TosaErrorValidator.evAxisLargerRank,
3395 TosaErrorValidator.evAxisSmallerZero,
3396 TosaErrorValidator.evShapeOfAxisNotOne,
3397 TosaErrorValidator.evWrongInputType,
3398 TosaErrorValidator.evWrongOutputType,
3399 TosaErrorValidator.evWrongRank,
3400 TosaErrorValidator.evWrongInputList,
3401 TosaErrorValidator.evWrongOutputList,
3402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003404 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003405 "concat": {
3406 "op": Op.CONCAT,
3407 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003408 "build_fcn": (
3409 build_concat,
3410 TosaTensorGen.tgConcat,
3411 TosaTensorValuesGen.tvgConcat,
3412 TosaArgGen.agAxis,
3413 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003414 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003415 "error_if_validators": (
3416 TosaErrorValidator.evAxisLargerRank,
3417 TosaErrorValidator.evAxisSmallerZero,
3418 TosaErrorValidator.evConcatInputRankMismatch,
3419 TosaErrorValidator.evConcatShapeSumMismatch,
3420 TosaErrorValidator.evConcatInputDimMismatch,
3421 TosaErrorValidator.evWrongInputType,
3422 TosaErrorValidator.evWrongOutputType,
3423 TosaErrorValidator.evWrongOutputList,
3424 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003425 },
3426 "pad": {
3427 "op": Op.PAD,
3428 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003429 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003430 "build_fcn": (
3431 build_pad,
3432 TosaTensorGen.tgBasic,
3433 TosaTensorValuesGen.tvgDefault,
3434 TosaArgGen.agPad,
3435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003436 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003437 "error_if_validators": (
3438 TosaErrorValidator.evWrongInputType,
3439 TosaErrorValidator.evPadSmallerZero,
3440 TosaErrorValidator.evWrongOutputType,
3441 TosaErrorValidator.evWrongInputList,
3442 TosaErrorValidator.evWrongOutputList,
3443 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003444 },
3445 "reshape": {
3446 "op": Op.RESHAPE,
3447 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003448 "build_fcn": (
3449 build_reshape,
3450 TosaTensorGen.tgBasic,
3451 TosaTensorValuesGen.tvgDefault,
3452 TosaArgGen.agReshape,
3453 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003454 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003455 "error_if_validators": (
3456 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3457 TosaErrorValidator.evWrongInputType,
3458 TosaErrorValidator.evWrongOutputType,
3459 TosaErrorValidator.evWrongInputList,
3460 TosaErrorValidator.evWrongOutputList,
3461 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003462 },
3463 "reverse": {
3464 "op": Op.REVERSE,
3465 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003466 "build_fcn": (
3467 build_reverse,
3468 TosaTensorGen.tgBasic,
3469 TosaTensorValuesGen.tvgDefault,
3470 TosaArgGen.agAxis,
3471 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003472 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003473 "error_if_validators": (
3474 TosaErrorValidator.evAxisSmallerZero,
3475 TosaErrorValidator.evAxisLargerRank,
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003481 },
3482 "slice": {
3483 "op": Op.SLICE,
3484 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003485 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003486 "build_fcn": (
3487 build_slice,
3488 TosaTensorGen.tgBasic,
3489 TosaTensorValuesGen.tvgDefault,
3490 TosaArgGen.agSlice,
3491 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003492 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003493 "error_if_validators": (
3494 TosaErrorValidator.evStartSmallerZero,
3495 TosaErrorValidator.evSizeSmallerEqualZero,
3496 TosaErrorValidator.evStartSizeOutsideBounds,
3497 TosaErrorValidator.evSizeOutputShapeMismatch,
3498 TosaErrorValidator.evInputSizeStartLengthMismatch,
3499 TosaErrorValidator.evWrongRank,
3500 TosaErrorValidator.evWrongInputType,
3501 TosaErrorValidator.evWrongOutputType,
3502 TosaErrorValidator.evWrongInputList,
3503 TosaErrorValidator.evWrongOutputList,
3504 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003505 },
3506 "tile": {
3507 "op": Op.TILE,
3508 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003509 "build_fcn": (
3510 build_tile,
3511 TosaTensorGen.tgBasic,
3512 TosaTensorValuesGen.tvgDefault,
3513 TosaArgGen.agTile,
3514 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003515 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003516 "error_if_validators": (
3517 TosaErrorValidator.evWrongInputType,
3518 TosaErrorValidator.evWrongOutputType,
3519 TosaErrorValidator.evWrongInputList,
3520 TosaErrorValidator.evWrongOutputList,
3521 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003522 },
3523 "transpose": {
3524 "op": Op.TRANSPOSE,
3525 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003526 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003527 "build_fcn": (
3528 build_transpose,
3529 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003531 TosaArgGen.agTranspose,
3532 ),
3533 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 "error_if_validators": (
3535 TosaErrorValidator.evIndexOutsideBounds,
3536 TosaErrorValidator.evIndexUsedTwice,
3537 TosaErrorValidator.evWrongInputType,
3538 TosaErrorValidator.evWrongOutputType,
3539 TosaErrorValidator.evWrongInputList,
3540 TosaErrorValidator.evWrongOutputList,
3541 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003542 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003543 # Data nodes
3544 "const": {
3545 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003546 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003547 "build_fcn": (
3548 build_const,
3549 TosaTensorGen.tgBasic,
3550 TosaTensorValuesGen.tvgDefault,
3551 None,
3552 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003553 "types": TYPE_FIB,
3554 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003555 "identity": {
3556 "op": Op.IDENTITY,
3557 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003558 "build_fcn": (
3559 build_unary,
3560 TosaTensorGen.tgBasic,
3561 TosaTensorValuesGen.tvgDefault,
3562 None,
3563 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 "types": TYPE_FIB,
3565 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003566 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003567 "gather": {
3568 "op": Op.GATHER,
3569 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3570 "operands": (1, 0),
3571 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003572 "build_fcn": (
3573 build_gather,
3574 TosaTensorGen.tgBasic,
3575 TosaTensorValuesGen.tvgDefault,
3576 None,
3577 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003578 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003579 "error_if_validators": (
3580 TosaErrorValidator.evWrongInputType,
3581 TosaErrorValidator.evWrongOutputType,
3582 TosaErrorValidator.evWrongInputList,
3583 TosaErrorValidator.evWrongOutputList,
3584 TosaErrorValidator.evWrongRank,
3585 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003586 },
3587 "scatter": {
3588 "op": Op.SCATTER,
3589 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003590 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003591 "operands": (2, 0),
3592 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003593 "build_fcn": (
3594 build_scatter,
3595 TosaTensorGen.tgScatter,
3596 TosaTensorValuesGen.tvgDefault,
3597 None,
3598 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003599 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003600 "error_if_validators": (
3601 TosaErrorValidator.evWrongInputType,
3602 TosaErrorValidator.evWrongOutputType,
3603 TosaErrorValidator.evWrongInputList,
3604 TosaErrorValidator.evWrongOutputList,
3605 TosaErrorValidator.evWrongRank,
3606 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003607 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003608 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003609 "resize": {
3610 "op": Op.RESIZE,
3611 "operands": (1, 0),
3612 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 "build_fcn": (
3614 build_resize,
3615 TosaTensorGen.tgNHWC,
3616 TosaTensorValuesGen.tvgDefault,
3617 TosaArgGen.agResize,
3618 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003619 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003620 "invalid_test_validators": (
3621 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
3622 TosaInvalidValidator.ivBadStride,
3623 ),
3624 "error_if_validators": (
3625 TosaErrorValidator.evMaxDimExceeded,
3626 TosaErrorValidator.evStrideSmallerEqualZero,
3627 TosaErrorValidator.evStrideLargerDimension,
3628 TosaErrorValidator.evStrideLargerEqualMax,
3629 TosaErrorValidator.evOffsetSmallerEqualMin,
3630 TosaErrorValidator.evOffsetLargerEqualMax,
3631 TosaErrorValidator.evShiftNotZero,
3632 TosaErrorValidator.evShiftSmallerOne,
3633 TosaErrorValidator.evShiftLargerEleven,
3634 TosaErrorValidator.evWrongInputType,
3635 TosaErrorValidator.evWrongOutputType,
3636 TosaErrorValidator.evWrongRank,
3637 TosaErrorValidator.evWrongInputList,
3638 TosaErrorValidator.evWrongOutputList,
3639 TosaErrorValidator.evBatchMismatch,
3640 TosaErrorValidator.evChannelMismatch,
3641 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003642 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003643 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003644 "cast": {
3645 "op": Op.CAST,
3646 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003647 "build_fcn": (
3648 build_cast,
3649 TosaTensorGen.tgBasic,
3650 TosaTensorValuesGen.tvgDefault,
3651 TosaArgGen.agCast,
3652 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003653 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003654 "error_if_validators": (
3655 TosaErrorValidator.evWrongInputType,
3656 TosaErrorValidator.evWrongOutputType,
3657 TosaErrorValidator.evWrongInputList,
3658 TosaErrorValidator.evWrongOutputList,
3659 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003660 },
3661 "rescale": {
3662 "op": Op.RESCALE,
3663 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003664 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003665 "build_fcn": (
3666 build_rescale,
3667 TosaTensorGen.tgBasic,
3668 TosaTensorValuesGen.tvgDefault,
3669 TosaArgGen.agRescale,
3670 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003671 "types": [
3672 DType.UINT8,
3673 DType.INT8,
3674 DType.INT16,
3675 DType.INT32,
3676 DType.INT48,
3677 DType.UINT16,
3678 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003679 "error_if_validators": (
3680 TosaErrorValidator.evInputZeroPointNotZero,
3681 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003682 TosaErrorValidator.evU16InputZeroPointNotValid,
3683 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003684 TosaErrorValidator.evScaleTrue,
3685 TosaErrorValidator.evScaleNotTrue,
3686 TosaErrorValidator.evWrongInputType,
3687 TosaErrorValidator.evWrongOutputType,
3688 TosaErrorValidator.evWrongRank,
3689 TosaErrorValidator.evWrongInputList,
3690 TosaErrorValidator.evWrongOutputList,
3691 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003692 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003693 # Custom
3694 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003695 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
        # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003699 "cond_if_const": {
3700 "op": Op.COND_IF,
3701 "operands": (0, 2),
3702 "build_fcn": (
3703 build_cond_if_const,
3704 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003705 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003706 TosaArgGen.agCondIf,
3707 ),
3708 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 "error_if_validators": (
3710 TosaErrorValidator.evOutputListThenGraphMismatch,
3711 TosaErrorValidator.evOutputListElseGraphMismatch,
3712 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003713 },
3714 "cond_if_binary": {
3715 "op": Op.COND_IF,
3716 "operands": (2, 0),
3717 "build_fcn": (
3718 build_cond_if_binary,
3719 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003720 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003721 TosaArgGen.agCondIf,
3722 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003723 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003724 "error_if_validators": (
3725 TosaErrorValidator.evInputListThenGraphMismatch,
3726 TosaErrorValidator.evInputListElseGraphMismatch,
3727 TosaErrorValidator.evOutputListThenGraphMismatch,
3728 TosaErrorValidator.evOutputListElseGraphMismatch,
3729 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003730 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003731 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003732 "while_loop": {
3733 "op": Op.WHILE_LOOP,
3734 "operands": (0, 1),
3735 "build_fcn": (
3736 build_while_loop,
3737 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003738 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003739 TosaArgGen.agWhileLoop,
3740 ),
3741 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 "error_if_validators": (
3743 TosaErrorValidator.evInputListOutputListMismatch,
3744 TosaErrorValidator.evInputListCondGraphMismatch,
3745 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3746 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3747 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3748 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003749 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003750 }
3751
Kevin Cheng550ccc52021-03-03 11:21:43 -08003752
Eric Kunzee5e26762020-10-13 16:11:07 -07003753class OutputShaper:
3754 # Methods in this class compute the expected output shape and datatype
3755 # for common classes of operations
3756 def __init__(self):
3757 pass
3758
3759 # These methods return arguments that can be used for
3760 # creating a new output tensor
3761 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a broadcasting binary operator.

    Both operands must have the same dtype, and the same rank unless the
    RankMismatch error case is being generated.

    For a valid test (error_name is None) each size-1 dimension of `a`
    takes the size of the matching dimension of `b`; for any ERROR_IF test
    the shape of `a` is used unchanged.

    The output dtype is a.dtype, except for the WrongOutputType error case
    where rng.choice picks one of the candidate dtypes other than a.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcasting is only applied when no error case was requested
    out_shape = [
        b.shape[idx] if (dim == 1 and error_name is None) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07003788
3789 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary operator without broadcasting.

    Asserts that `a` and `b` have the same rank, the same dtype and the
    same size in every dimension, then returns the result of
    ser.addOutput() called with a fresh copy of that shape and a.dtype.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for dim_a, dim_b in zip(a.shape, b.shape):
        # No broadcasting allowed: every dimension has to agree exactly
        assert dim_a == dim_b
        out_shape.append(dim_a)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003800
3801 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for a unary operator.

    The output has the shape of `a`.  Its dtype is a.dtype, except for the
    WrongOutputType error case where rng.choice picks one of the candidate
    dtypes other than a.dtype.

    Returns the result of ser.addOutput(a.shape, dtype).
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(a.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07003817
3818 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for a select operator.

    `a` and `b` must have the same dtype; `cond`, `a` and `b` must have the
    same rank unless the RankMismatch error case is being generated.

    For a valid test (error_name is None) each size-1 dimension of `cond`
    becomes the largest of the matching dimensions of `cond`, `a` and `b`;
    in every other case the dimension of `cond` is used.

    The output dtype is a.dtype, except for the WrongOutputType error case
    where rng.choice picks one of the candidate dtypes other than a.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and error_name is None)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07003845
3846 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a broadcasting comparison operator.

    Both operands must have the same dtype, and the same rank unless the
    RankMismatch error case is being generated.

    Each size-1 dimension of `a` takes the size of the matching dimension
    of `b` when `b` has such a dimension; otherwise the dimension of `a`
    is kept.

    The output dtype is DType.BOOL, except for the WrongOutputType error
    case where rng.choice picks one of the listed non-boolean dtypes.

    Returns the result of ser.addOutput(shape, dtype).
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast; b may be shorter than a in the rank mismatch case
    out_shape = [
        b.shape[idx] if (dim == 1 and len(b.shape) > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.BOOL)

    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    return ser.addOutput(out_shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07003873
3874 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for a reduction along `axis`.

    Starting from a copy of a.shape, the size at `axis` is set to 1 unless
    one of the AxisSmallerZero, AxisLargerRank or ShapeOfAxisNotOne error
    cases is being generated.  For ShapeOfAxisNotOne, a dimension that is
    already 1 is replaced by rng.integers(2, 10).

    The output dtype is a.dtype, except for the WrongOutputType error case
    where rng.choice picks one of the candidate dtypes other than a.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()

    # Error cases that must not get the regular "reduced to 1" dimension
    keeps_axis_size = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keeps_axis_size:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07003900
3901 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor for an argmax along `axis`.

    Starting from a copy of a.shape, the dimension at `axis` is removed
    unless the AxisSmallerZero or AxisLargerRank error case is being
    generated.  Two error cases then corrupt the shape on purpose:

    * ArgmaxOutputRankMismatch: a random choice either drops the first
      dimension (only when more than one is left) or appends a
      dimension of size 1.
    * ArgmaxOutputShapeMismatch: every dimension is increased by
      rng.integers(1, 10).

    The output dtype is DType.INT32, except for the WrongOutputType error
    case where rng.choice picks one of the other candidate dtypes.

    Returns the result of ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([DType.INT32]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07003932
3933 @staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor for a 2D convolution.

    Layouts:
        IFM:    NHWC
        Filter: OHWI
        OFM:    NHWC

    Height uses padding[0:2], strides[0] and dilations[0]; width uses
    padding[2:4], strides[1] and dilations[1].  For the
    ConvOutputShapeMismatch error case the height, the width or both are
    enlarged by a random multiple of the matching stride.

    The output dtype follows the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT).  Any other input dtype gives INT32
    for the WrongInputType error case and raises otherwise.  For the
    WrongOutputType error case rng.choice then picks from
    usableDTypes(excludes=[...]) with that dtype excluded.

    Returns the result of ser.addOutput(shape, dtype).
    """

    def _out_size(in_size, pad_before, pad_after, kernel_size, dilation, stride):
        # Output size along a single spatial axis
        return (
            in_size - 1 + pad_before + pad_after - (kernel_size - 1) * dilation
        ) // stride + 1

    out_h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1: height only, 2: width only, 3: both
        target = rng.choice(amounts)
        # increment in multiples of stride to not hit non-integer error case
        if target in [1, 3]:
            out_h = out_h + (rng.choice(amounts) * strides[0])
        if target in [2, 3]:
            out_w = out_w + (rng.choice(amounts) * strides[1])

    out_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if ifm.dtype == DType.INT8:
        result_dtype = DType.INT32
    elif ifm.dtype == DType.INT16:
        result_dtype = DType.INT48
    elif ifm.dtype == DType.FLOAT:
        result_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        result_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        mismatched = list(usableDTypes(excludes=[result_dtype]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003984
3985 @staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Create the output tensor for a 3D convolution.

    Layouts:
        IFM:    NDHWC
        Filter: ODHWI
        OFM:    NDHWC

    Depth uses padding[0:2], strides[0] and dilations[0]; height uses
    padding[2:4], strides[1] and dilations[1]; width uses padding[4:6],
    strides[2] and dilations[2].  For the ConvOutputShapeMismatch error
    case one of depth, height or width (or all three) is enlarged by a
    random multiple of the matching stride.

    The output dtype follows the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT).  Any other input dtype gives INT32
    for the WrongInputType error case and raises otherwise.  For the
    WrongOutputType error case rng.choice then picks from
    usableDTypes(excludes=[...]) with that dtype excluded.

    Returns the result of ser.addOutput(shape, dtype).
    """

    def _out_size(in_size, pad_before, pad_after, kernel_size, dilation, stride):
        # Output size along a single spatial axis
        return (
            in_size - 1 + pad_before + pad_after - (kernel_size - 1) * dilation
        ) // stride + 1

    out_d = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = _out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3, 4]
        # 1: depth only, 2: height only, 3: width only, 4: all of them
        target = rng.choice(amounts)
        # increment in multiples of stride to not hit non-integer error case
        if target in [1, 4]:
            out_d = out_d + (rng.choice(amounts) * strides[0])
        if target in [2, 4]:
            out_h = out_h + (rng.choice(amounts) * strides[1])
        if target in [3, 4]:
            out_w = out_w + (rng.choice(amounts) * strides[2])

    out_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    if ifm.dtype == DType.INT8:
        result_dtype = DType.INT32
    elif ifm.dtype == DType.INT16:
        result_dtype = DType.INT48
    elif ifm.dtype == DType.FLOAT:
        result_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        result_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        mismatched = list(usableDTypes(excludes=[result_dtype]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, result_dtype)
4046
4047 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, strides, padding, dilations, error_name=None
):
    """Create the output tensor for a depthwise 2D convolution.

    Layouts:
        IFM:    NHWC
        Filter: HWCM
        OFM:    NHW C*M

    Height uses padding[0:2], strides[0] and dilations[0]; width uses
    padding[2:4], strides[1] and dilations[1].  The output channel count
    is filter.shape[2] * filter.shape[3].  For the ConvOutputShapeMismatch
    error case the height, the width or both are enlarged by a random
    multiple of the matching stride.

    The output dtype follows the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT).  Any other input dtype gives INT32
    for the WrongInputType error case and raises otherwise.  For the
    WrongOutputType error case rng.choice then picks from
    usableDTypes(excludes=[...]) with that dtype excluded.

    Returns the result of ser.addOutput(shape, dtype).
    """

    def _out_size(in_size, pad_before, pad_after, kernel_size, dilation, stride):
        # Output size along a single spatial axis
        return (
            in_size - 1 + pad_before + pad_after - (kernel_size - 1) * dilation
        ) // stride + 1

    out_h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1: height only, 2: width only, 3: both
        target = rng.choice(amounts)
        # increment in multiples of stride to not hit non-integer error case
        if target in [1, 3]:
            out_h = out_h + (rng.choice(amounts) * strides[0])
        if target in [2, 3]:
            out_w = out_w + (rng.choice(amounts) * strides[1])

    out_shape = [ifm.shape[0], out_h, out_w, filter.shape[2] * filter.shape[3]]

    if ifm.dtype == DType.INT8:
        result_dtype = DType.INT32
    elif ifm.dtype == DType.INT16:
        result_dtype = DType.INT48
    elif ifm.dtype == DType.FLOAT:
        result_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        result_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        mismatched = list(usableDTypes(excludes=[result_dtype]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004099
4100 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor for a 2D pooling operator (input is NHWC).

    With positive strides and non-negative padding the output height and
    width are computed from the input size, padding, kernel and stride;
    otherwise both are set to 1.  For the PoolingOutputShapeMismatch error
    case the height, the width or both are enlarged by a random multiple
    of the matching stride.  Batch and channel sizes are taken from `ifm`.

    The output dtype is ifm.dtype, except for the WrongOutputType error
    case where rng.choice picks one of the candidate dtypes other than
    ifm.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        out_h = 1
        out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1: height only, 2: width only, 3: both
        target = rng.choice(amounts)
        # increment in multiples of stride to not hit non-integer error case
        if target in [1, 3]:
            out_h = out_h + (rng.choice(amounts) * stride[0])
        if target in [2, 3]:
            out_w = out_w + (rng.choice(amounts) * stride[1])
    out_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([ifm.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004135
4136 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Create the output tensor for a fully connected operator.

    Layouts:
        input:  N, IC
        filter: OC, IC
        output: N, OC

    The correct output dtype follows the input dtype (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT).  For the WrongOutputType error case
    rng.choice picks one of the other dtypes from the list below instead.

    An input dtype outside those three gives INT32 for the WrongInputType
    error case.  For every other error_name, including WrongOutputType,
    it raises Exception.  WrongOutputType used to fail here with an
    UnboundLocalError because no list of incorrect types was ever
    assigned.

    Returns the result of ser.addOutput(shape, dtype).
    """
    output_shape = [input.shape[0], filter.shape[0]]

    if input.dtype == DType.INT8:
        correct_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        correct_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        correct_dtype = DType.FLOAT
    else:
        correct_dtype = None

    if correct_dtype is None:
        if error_name != ErrorIf.WrongInputType:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # Every listed dtype except the correct one, in list order so that
        # rng.choice sees the same sequence as before
        all_out_dtypes = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in all_out_dtypes if t != correct_dtype)
        out_dtype = rng.choice(a=incorrect_types)
    else:
        out_dtype = correct_dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004183
4184 @staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for a matmul operator.

    Layouts:
        a:   N, H, C
        b:   N, C, W
        out: N, H, W

    The correct output dtype follows the dtype of `a` (INT8 -> INT32,
    INT16 -> INT48, FLOAT -> FLOAT).  For the WrongOutputType error case
    rng.choice picks one of the other dtypes from the list below instead.

    A dtype of `a` outside those three gives INT32 for the WrongInputType
    error case.  For every other error_name, including WrongOutputType,
    it raises Exception.  WrongOutputType used to fail here with an
    UnboundLocalError because no list of incorrect types was ever
    assigned.

    Returns the result of ser.addOutput(shape, dtype).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if a.dtype == DType.INT8:
        correct_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        correct_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        correct_dtype = DType.FLOAT
    else:
        correct_dtype = None

    if correct_dtype is None:
        if error_name != ErrorIf.WrongInputType:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # Every listed dtype except the correct one, in list order so that
        # rng.choice sees the same sequence as before
        all_out_dtypes = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in all_out_dtypes if t != correct_dtype)
        out_dtype = rng.choice(a=incorrect_types)
    else:
        out_dtype = correct_dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004231
4232 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Create the output tensor for a concatenation of the tensors in `a`.

    The output shape starts as a copy of the first input's shape.  Unless
    the ConcatInputRankMismatch, AxisLargerRank or AxisSmallerZero error
    case is being generated, the size along `axis` of every further input
    is added to it.  For the ConcatShapeSumMismatch error case the size
    along `axis` is additionally increased by rng.integers(5, 10).

    The output dtype is the first input's dtype, except for the
    WrongOutputType error case where rng.choice picks one of the other
    candidate dtypes.

    Returns the result of ser.addOutput(shape, dtype).
    """
    first = a[0]
    others = a[1:]

    # calculate the output shape, if possible, otherwise just use the first input shape
    out_shape = first.shape.copy()
    shape_is_computable = (
        # unable to concat tensors of different ranks
        error_name != ErrorIf.ConcatInputRankMismatch
        # unable to concat tensors along an invalid axis
        and error_name not in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
    )
    if shape_is_computable:
        for extra in others:
            out_shape[axis] += extra.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, first.dtype)

    candidates = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    }
    mismatched = list(candidates - set([first.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004265
4266 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for a pad operator.

    Every dimension of a copy of a.shape is enlarged by padding[i][0] and
    padding[i][1].  For the PadSmallerZero error case any resulting
    dimension below 1 is raised to 1.

    The output dtype is a.dtype, except for the WrongOutputType error case
    where rng.choice picks one of the candidate dtypes other than a.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(out_shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    # Fix negative output shape if error_if test causes it
    if error_name == ErrorIf.PadSmallerZero and min(out_shape) < 1:
        out_shape = [max(dim, 1) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004292
4293 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for a reshape to `shape`.

    The output shape is a copy of `shape`.  For the
    TensorSizeInputOutputMismatch error case every dimension of that copy
    is increased by rng.integers(1, 10).

    The output dtype is a.dtype, except for the WrongOutputType error case
    where rng.choice picks one of the candidate dtypes other than a.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    mismatched = list(set(candidates) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004315
4316 @staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the output tensor for a slice operator.

    The output shape is a copy of `size`; `start` is not used here.  For
    the SizeOutputShapeMismatch error case every dimension is changed:
    dimensions of at most 2 grow by 1 or 2, larger ones change by one of
    -2, -1, 1 or 2.

    The output dtype is a.dtype, except for the WrongOutputType error case
    where rng.choice picks one of the candidate dtypes other than a.dtype.

    Returns the result of ser.addOutput(shape, dtype).
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        mismatched = list(set(candidates) - set([a.dtype]))
        result_dtype = rng.choice(mismatched)
    else:
        result_dtype = a.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dimensions may only grow, so they never drop below 1
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004345
4346 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor for a TILE test case.

    Each output dimension is the matching dimension of ``a.shape``
    multiplied by the matching entry of ``multiples``; the two must have
    the same length. The dtype is ``a.dtype`` unless ``error_name`` is
    ``ErrorIf.WrongOutputType``, in which case it is picked with
    ``rng.choice`` from the candidate dtypes that differ from ``a.dtype``.

    Returns the value returned by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, dim in enumerate(a.shape):
        output_shape[idx] = dim * multiples[idx]

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        outputDType = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004369
4370 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor for a TRANSPOSE test case.

    Output dimension ``i`` is ``a.shape[perms[i]]``; ``perms`` must have
    one entry per dimension of ``a``.

    * ``ErrorIf.IndexOutsideBounds``: ``perms`` is not used for indexing;
      every output dimension is set to ``a.shape[0]`` instead.
    * ``ErrorIf.WrongOutputType``: the dtype is picked with ``rng.choice``
      from the candidate dtypes that differ from ``a.dtype``.

    Returns the value returned by ``ser.addOutput``.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    if error_name == ErrorIf.IndexOutsideBounds:
        # Avoid indexing a.shape with the permutation in this case.
        output_shape[:] = [a.shape[0] for _ in output_shape]
    else:
        output_shape[:] = [
            a.shape[perms[idx]] for idx in range(len(output_shape))
        ]

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        outputDType = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004397
4398 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor for a GATHER test case.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]`` and the dtype
    is ``values.dtype``, unless ``error_name`` is
    ``ErrorIf.WrongOutputType``, in which case the dtype is picked with
    ``rng.choice`` from the candidate dtypes that differ from
    ``values.dtype``.

    Returns the value returned by ``ser.addOutput``.
    """
    # NOTE(review): only the rank-3 check on ``values`` is taken to be
    # skipped for WrongRank tests; confirm against the upstream file.
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    outputDType = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        outputDType = rng.choice(list(set(candidates) - {values.dtype}))

    return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004421
4422 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor for a SCATTER test case.

    The output shape is ``values_in.shape`` itself (not a copy) and the
    dtype is ``values_in.dtype``, unless ``error_name`` is
    ``ErrorIf.WrongOutputType``, in which case the dtype is picked with
    ``rng.choice`` from the candidate dtypes that differ from
    ``values_in.dtype``.

    Returns the value returned by ``ser.addOutput``.
    """
    # NOTE(review): only the rank-3 check on ``values_in`` is taken to be
    # skipped for WrongRank tests; confirm against the upstream file.
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    # Dimensions shared between the operands must agree.
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    outputDType = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        outputDType = rng.choice(list(set(candidates) - {values_in.dtype}))

    return ser.addOutput(values_in.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004448
4449 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor for a TABLE test case.

    The output has the same shape as ``input``. Its dtype is
    ``DType.INT32`` when ``input.dtype`` is ``DType.INT16`` and
    ``DType.INT8`` otherwise. With ``ErrorIf.WrongOutputType`` the dtype is
    instead picked with ``rng.choice`` from the candidate dtypes other
    than that expected one.

    Unless ``error_name`` is ``ErrorIf.WrongInputType``, ``input.dtype``
    must be ``DType.INT16`` or ``DType.INT8``.

    Returns the value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        # Drop the correct dtype so the random pick is always wrong.
        candidates.remove(output_dtype)
        output_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004466
4467 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor for a RESIZE test case.

    The output shape is normally
    ``[input.shape[0], output_dims[0], output_dims[1], input.shape[3]]``
    with dtype ``output_dtype``. Only ``input``, ``output_dims``,
    ``output_dtype``, ``rng`` and ``error_name`` are read here; the other
    parameters are accepted but unused.

    * ``ErrorIf.WrongRank``: the shape is built only from ``input.shape[0]``
      and ``output_dims[0]``.
    * ``ErrorIf.BatchMismatch``: the first dimension is increased by
      ``rng.integers(1, 10)``.
    * ``ErrorIf.ChannelMismatch``: the last dimension is increased by
      ``rng.integers(1, 10)``.

    Returns the value returned by ``serializer.addOutput``.
    """
    if error_name == ErrorIf.WrongRank:
        # Only index 0 of each sequence is read on this path.
        wrong_rank_shape = [
            input.shape[0],
            output_dims[0],
            output_dims[0],
            input.shape[0],
        ]
        return serializer.addOutput(wrong_rank_shape, output_dtype)

    first_dim = input.shape[0]
    last_dim = input.shape[3]
    if error_name == ErrorIf.BatchMismatch:
        first_dim = first_dim + rng.integers(1, 10)
    elif error_name == ErrorIf.ChannelMismatch:
        last_dim = last_dim + rng.integers(1, 10)

    result_shape = [first_dim, output_dims[0], output_dims[1], last_dim]
    return serializer.addOutput(result_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004514
4515 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not used here. Returns the
    value returned by ``ser.addOutput``.
    """
    # The shape object is forwarded as-is, without copying.
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004518
4519 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Register the output tensor for a TRANSPOSE_CONV2D test case.

    The output dtype follows ``ifm.dtype``: INT8 gives INT32, INT16 gives
    INT48 and FLOAT gives FLOAT. Any other input dtype raises
    ``Exception`` unless ``error_name`` is ``ErrorIf.WrongInputType``, in
    which case INT32 is used.

    * ``ErrorIf.ConvOutputShapeMismatch``: index 1, index 2 or both of
      ``output_shape`` are increased by a random 1, 2 or 3. This changes
      the caller's list in place.
    * ``ErrorIf.WrongOutputType``: the dtype is replaced by one picked
      with ``rng.choice`` from ``usableDTypes(excludes=[out_dtype])``.

    Returns the value returned by ``ser.addOutput``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1: change index 1 only, 2: change index 2 only, 3: change both.
        change = rng.choice(choices)
        if change in [1, 3]:
            output_shape[1] += rng.choice(choices)
        if change in [2, 3]:
            output_shape[2] += rng.choice(choices)

    dtype_pairs = (
        (DType.INT8, DType.INT32),
        (DType.INT16, DType.INT48),
        (DType.FLOAT, DType.FLOAT),
    )
    for in_dtype, mapped_dtype in dtype_pairs:
        if ifm.dtype == in_dtype:
            out_dtype = mapped_dtype
            break
    else:
        if error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is
            # incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(output_shape, out_dtype)