blob: 262a652630d16da76156b20c76d7557a90db7e39 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
16from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000017from tosa.DType import DType
18from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010019
20
Eric Kunzee5e26762020-10-13 16:11:07 -070021class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010022 # Maximum rank of tensor supported by test generator.
23 TOSA_TENSOR_MAX_RANK = 6
24
def __init__(self, args):
    """Initialise generator state from the parsed command-line ``args``.

    Uses ``args.output_dir`` as the base output path and ``args.random_seed``
    to seed the numpy random generator, then builds the operator lists and
    the quantization helper.
    """
    self.args = args
    self.basePath = args.output_dir
    seed = args.random_seed
    self.random_seed = seed
    # No serializer exists until createSerializer() is called for a test
    self.ser = None
    self.rng = np.random.default_rng(seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape() returns this fixed shape instead of a random one
    self.targetted_shape = None
36
def createSerializer(self, opName, testPath):
    """Create ``<basePath>/<opName>/<testPath>`` and a serializer writing there.

    Also records ``<opName>/<testPath>`` in ``self.testPath`` for serialize().
    """
    self.testPath = os.path.join(opName, testPath)
    target_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(target_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(target_dir)
43
def getSerializer(self):
    """Return the current serializer (None until createSerializer() runs)."""
    current = self.ser
    return current
46
def serialize(self, testName):
    """Write ``<testName>.tosa`` and ``desc.json`` into the current test dir.

    The binary comes from ``self.ser.serialize()``; the JSON descriptor from
    ``self.ser.writeJson`` given the .tosa file name.
    """
    out_dir = os.path.join(self.basePath, self.testPath)
    tosa_file = "{}.tosa".format(testName)

    with open(os.path.join(out_dir, tosa_file), "wb") as fd:
        fd.write(self.ser.serialize())

    with open(os.path.join(out_dir, "desc.json"), "w") as fd:
        fd.write(self.ser.writeJson(tosa_file))
Eric Kunzee5e26762020-10-13 16:11:07 -070055
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a fresh generator.

    Seeded with ``seed`` when given, otherwise with ``self.random_seed + 1``.
    """
    effective_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(effective_seed)
60
def getRandTensor(self, shape, dtype):
    """Return a random numpy array of ``shape`` with values valid for ``dtype``.

    BOOL gives a random mix of False/True; FLOAT gives float32 values drawn
    from ``self.rng.random``. Integer types are drawn uniformly over their
    range (INT4 uses the TOSA weight range -7..7) and returned as int32,
    except INT48 which is returned as int64.

    Raises:
        Exception: if ``dtype`` is not one of the handled types.
    """
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))

    # (dtype, inclusive low, exclusive high, numpy cast) per integer type
    integer_ranges = (
        # TOSA specific INT4 weight range from -7 to 7
        (DType.INT4, -7, 8, np.int32),
        (DType.INT8, -128, 128, np.int32),
        (DType.UINT8, 0, 256, np.int32),
        (DType.INT16, -32768, 32768, np.int32),
        (DType.UINT16, 0, 65536, np.int32),
        (DType.INT32, -(1 << 31), 1 << 31, np.int32),
        (DType.INT48, -(1 << 47), 1 << 47, np.int64),
    )
    for int_dtype, low, high, cast in integer_ranges:
        if dtype == int_dtype:
            return cast(self.rng.integers(low=low, high=high, size=shape))

    raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -070087
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add a random-valued placeholder for each (shape, dtype) pair.

    ``shape_list`` and ``dtype_list`` must be the same length. Returns the
    tensors from ``self.ser.addPlaceholder`` in input order.
    """
    assert len(shape_list) == len(dtype_list)
    return [
        self.ser.addPlaceholder(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
98
def buildConstTensors(self, shape_list, dtype_list):
    """Add a random-valued constant for each (shape, dtype) pair.

    ``shape_list`` and ``dtype_list`` must be the same length. Returns the
    tensors from ``self.ser.addConst`` in input order.
    """
    assert len(shape_list) == len(dtype_list)
    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, values))
    return consts
109
def makeShape(self, rank):
    """Return an int32 shape array.

    If a non-empty target shape was set via setTargetShape(), that shape is
    returned and ``rank`` is ignored. Otherwise ``rank`` dimensions are drawn
    from ``self.rng.integers`` using ``args.tensor_shape_range[0]`` (low,
    inclusive) and ``args.tensor_shape_range[1]`` (high, exclusive).
    """
    target = self.targetted_shape
    # Explicit None/length test instead of truthiness: `if array:` raises
    # ValueError for multi-element numpy arrays and treats e.g. array([0])
    # as unset. None and [] still mean "no target", as before.
    if target is not None and len(target) > 0:
        return np.int32(target)

    shape_range = self.args.tensor_shape_range
    return np.int32(
        self.rng.integers(low=shape_range[0], high=shape_range[1], size=rank)
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700120
def setTargetShape(self, shape):
    """Make makeShape() return ``shape``; None restores random shapes."""
    self.targetted_shape = shape
123
def randInt(self, low=0, high=256):
    """Return one random np.int32 drawn from ``self.rng`` in [low, high)."""
    draw = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draw)[0]
126
def getRandNumberDType(self, dtype):
    """Return a single random scalar valid for ``dtype``.

    FLOAT returns ``self.rng.random()``; BOOL returns a random choice of
    False/True; INT48 returns an np.int64; every other integer type returns
    an np.int32 drawn uniformly from the type's range (INT4 uses the TOSA
    weight range -7..7). UINT8/UINT16 use the same ranges as getRandTensor.

    Raises:
        Exception: if ``dtype`` is not one of the handled types.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # INT48 does not fit in int32, so return a 64-bit value
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
149
def shapeStr(self, shape):
    """Render ``shape`` as its dimensions joined by "x" (e.g. [2, 3] -> "2x3")."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
def typeStr(self, t):
    """Return the short dtype name used in generated test names.

    A list is treated as a dtype pair: it must hold at least two entries,
    and the first two are rendered as "<first>x<second>".

    Raises:
        Exception: if ``t`` is not one of the handled dtypes.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        first = self.typeStr(t[0])
        second = self.typeStr(t[1])
        return "{}x{}".format(first, second)

    # Checked in order with ==, like the comparison chain it replaces
    dtype_names = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.UINT16, "u16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for dtype, name in dtype_names:
        if t == dtype:
            return name
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184
def typeWidth(self, t):
    """Return the width in bits of dtype ``t``.

    Covers the integer types (INT4..INT48, UINT8, UINT16) plus FLOAT (32)
    and BOOL (1).

    Raises:
        Exception: if ``t`` is not one of the handled dtypes.
    """
    bit_widths = (
        (DType.INT4, 4),
        (DType.INT8, 8),
        (DType.UINT8, 8),
        (DType.INT16, 16),
        (DType.UINT16, 16),
        (DType.INT32, 32),
        (DType.INT48, 48),
        (DType.FLOAT, 32),
        (DType.BOOL, 1),
    )
    for dtype, bits in bit_widths:
        if t == dtype:
            return bits
    raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
# Operator build functions
# The argument generators (see TosaArgGen) return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
# used to generate the test name and the build_fcn_arg_list is expanded and
# passed to one of the operator test build functions below.
213
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a unary operator applied to tensor ``a`` and return its output.

    ``op`` is normally an operator-list entry (a dict with "op" and
    "operands"). A bare int opcode (per the original comment, what
    build_placeholder passes) and the IDENTITY op are added directly,
    without ERROR_IF processing. Otherwise returns None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Direct paths: int opcode from build_placeholder, and IDENTITY
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
        return result_tens
    if op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
        return result_tens

    # A WrongOutputType test with a non-8-bit output gets fresh quant info
    # generated for the input and output dtypes
    needs_new_qinfo = (
        error_name == ErrorIf.WrongOutputType
        and result_tens.dtype not in [DType.INT8, DType.UINT8]
    )
    if needs_new_qinfo:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, a.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    n_placeholders, n_consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
    return result_tens
260
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a broadcasting binary operator on ``a`` and ``b``.

    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )
    n_placeholders, n_consts = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
293
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting binary operator on ``a`` and ``b``.

    ``validator_fcns`` and ``error_name`` are accepted but not used here:
    no ERROR_IF checks are run. Returns the output tensor.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    inputs = [a.name, b.name]
    outputs = [out_tensor.name]
    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
298
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift of ``a`` by ``b``.

    ``round`` is stored in the operator's ArithmeticRightShiftAttribute.
    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )
    n_placeholders, n_consts = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
336
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply of ``a`` and ``b`` with ``shift`` in its MulAttribute.

    For non-FLOAT inputs the output dtype is forced to INT32. For a
    WrongOutputType test it is then overwritten with a random choice of
    INT8/INT16/INT48. Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: integer results are INT32
    if a.dtype != DType.FLOAT:
        result_tens.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tens.setDtype(self.rng.choice(wrong_dtypes))

    n_placeholders, n_consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
381
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table lookup on ``a``; ``table`` goes into its TableAttribute.

    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    n_placeholders, n_consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
415
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select with condition ``cond`` choosing between ``a`` and ``b``.

    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
    n_placeholders, n_consts = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
452
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a comparison operator between ``a`` and ``b``.

    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )
    n_placeholders, n_consts = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
491
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax of ``a`` along ``axis`` (stored in an AxisAttribute).

    Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
    n_placeholders, n_consts = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
526
def build_pool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator on ``input``.

    ``kernel``, ``stride`` and ``pad`` go into the PoolAttribute. For a
    WrongInputType test with a non-8-bit input, new UnaryQuantInfo is
    generated. Returns the output tensor, or None when
    TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    result_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # A non-8-bit input (WrongInputType test) needs quant info for that type
    needs_new_qinfo = error_name == ErrorIf.WrongInputType and input.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]
    if needs_new_qinfo:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, input.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    n_placeholders, n_consts = op["operands"]
    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=n_placeholders + n_consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad)
    self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
    return result_tens
585
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution of ``ifm`` with ``filter`` and ``bias``.

    ``padding`` must have 4 entries; it goes with ``strides`` and
    ``dilations`` into the ConvAttribute. For a WrongInputType test with a
    non-8-bit input, new ConvQuantInfo is generated. Returns the output
    tensor, or None when TosaErrorValidator.evValidateErrorIfs rejects it.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # A non-8-bit input (WrongInputType test) needs quant info for that type
    eight_bit = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    operand_total = sum(op["operands"])
    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=operand_total,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations)
    self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
    return result_tens
649
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution of ``ifm`` with ``filter`` and ``bias``.

    ``padding`` must have 6 entries; it goes with ``strides`` and
    ``dilations`` into the ConvAttribute. For a WrongInputType test with a
    non-8-bit input, new ConvQuantInfo is generated. Returns the output
    tensor, or None when TosaErrorValidator.evValidateErrorIfs rejects it.
    """
    assert len(padding) == 6
    result_tens = OutputShaper.conv3dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # A non-8-bit input (WrongInputType test) needs quant info for that type
    eight_bit = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    operand_total = sum(op["operands"])
    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    error_if_args = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=operand_total,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations)
    self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
    return result_tens
713
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D transpose convolution of ``ifm`` with ``filter`` and ``bias``.

    ``out_pad`` must have 4 entries; it goes with ``stride`` and
    ``output_shape`` into the TransposeConvAttribute. For a WrongInputType
    test with a non-8-bit input, new ConvQuantInfo is generated. Returns the
    output tensor, or None when TosaErrorValidator.evValidateErrorIfs
    rejects the test.
    """
    assert len(out_pad) == 4
    result_tens = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, error_name
    )

    # A non-8-bit input (WrongInputType test) needs quant info for that type
    eight_bit = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, result_tens.dtype),
        )

    operand_total = sum(op["operands"])
    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result_tens.name]
    )

    # Note: output_shape passed to the validators is the computed result
    # shape, not the output_shape argument
    error_if_args = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=operand_total,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape)
    self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
    return result_tens
776
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DEPTHWISE_CONV2D operator to the serializer.

    The output tensor comes from ``OutputShaper.depthwiseConv2dOp``. The
    operator carries a ConvAttribute built from padding, strides and
    dilations. Returns the output tensor, or None when the ERROR_IF
    validators reject the configuration.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # With a deliberately wrong (non 8-bit) input type, rebuild qinfo so it
    # matches the actual input/output dtypes.
    needs_new_qinfo = error_name == ErrorIf.WrongInputType and (
        ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if needs_new_qinfo:
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out_tensor.dtype),
        )

    operand_count = sum(op["operands"])
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations)
    self.ser.addOperator(op["op"], inputs, outputs, attr, qinfo)
    return out_tensor
839
def build_fully_connected(
    self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a FULLY_CONNECTED operator (no attribute, optional qinfo).

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, None, qinfo)
    return out_tensor
876
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Add a MATMUL operator for inputs ``a`` and ``b`` (optional qinfo).

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, None, qinfo)
    return out_tensor
910
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reduction operator (``op["op"]``) over ``axis`` of tensor ``a``.

    ``validator_fcns`` is optional (default None), like in the other
    build_* methods, so callers may omit it. The operator carries an
    AxisAttribute. Returns the output tensor, or None when ERROR_IF
    validation fails.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
945
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a CLAMP operator with randomly drawn bounds.

    Two random values of ``a.dtype`` become the bounds. For the
    ErrorIf.MaxSmallerMin test they are forced to differ and then swapped,
    so that max_val < min_val. FLOAT inputs put the bounds in the last two
    ClampAttribute slots; other dtypes use the first two. Returns the
    output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    rand_value = self.getRandNumberDType
    bounds = [rand_value(a.dtype), rand_value(a.dtype)]

    inverted = error_name == ErrorIf.MaxSmallerMin
    if inverted:
        # Equal bounds cannot give max < min, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = [rand_value(a.dtype), rand_value(a.dtype)]

    low, high = min(bounds), max(bounds)
    if inverted:
        min_val, max_val = high, low
    else:
        min_val, max_val = low, high

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    if a.dtype == DType.FLOAT:
        clamp_args = (0, 0, min_val, max_val)
    else:
        clamp_args = (min_val, max_val, 0, 0)
    attr = ts.TosaSerializerAttribute()
    attr.ClampAttribute(*clamp_args)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
996
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a LEAKY_RELU operator whose attribute value is a random float.

    No ERROR_IF validation is done here; ``validator_fcns`` is unused.
    Returns the output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    attr = ts.TosaSerializerAttribute()

    leak_value = self.getRandNumberDType(DType.FLOAT)
    attr.LeakyReluAttribute(leak_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
1005
# TODO: PRELU still needs an additional type/input.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add a PRELU operator wired to the single input tensor ``a``.

    No ERROR_IF validation is done here; ``validator_fcns`` is unused.
    Returns the output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1012
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a SIGMOID operator (no attribute).

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1043
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a TANH operator (no attribute).

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1074
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a CONCAT operator over a variable number of input tensors.

    The positional arguments are the input tensors followed by the axis,
    which is packed last so a variable-length tensor list can travel with
    it. Except for the WrongInputType error test, the axis must be an int.
    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    axis = a[-1]
    tensors = a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [out_tensor.name]
    )

    first = tensors[0]
    checks = dict(
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out_tensor.shape,
        input_dtype=first.dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001123
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator.

    The PadAttribute receives the flattened ``padding`` plus both pad
    constants; ``qinfo`` is forwarded to the serializer. Returns the output
    tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr, qinfo)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001169
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a RESHAPE operator targeting ``newShape``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1205
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a REVERSE operator along ``axis`` (carried as an AxisAttribute).

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1240
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a TRANSPOSE operator using permutation ``perms``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1275
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a SLICE operator described by ``start`` and ``size``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1313
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a TILE operator that repeats ``a`` by ``multiples``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1347
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Add a GATHER operator with a freshly generated INT32 index constant.

    The index tensor has shape [values.shape[0], W]. W is drawn with
    ``self.randInt`` from ``self.args.tensor_shape_range``, and every entry
    lies in [0, values.shape[1]), so no index exceeds the values tensor.
    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    k_extent = values.shape[1]
    lo, hi = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    width = self.randInt(lo, hi)
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values.shape[0], width])
    )
    index_tensor = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, index_tensor, error_name
    )

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, index_tensor.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1394
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Add a SCATTER operator with a freshly generated INT32 index constant.

    The index tensor has shape [values_in.shape[0], input.shape[1]], and
    every entry lies in [0, values_in.shape[1]), so no index exceeds the
    values_in tensor. Returns the output tensor, or None when ERROR_IF
    validation fails.
    """
    k_extent = values_in.shape[1]
    width = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values_in.shape[0], width])
    )
    index_tensor = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, index_tensor, input, error_name
    )

    p_count, c_count = op["operands"]
    # Let the ERROR_IF helper invalidate the name lists for error tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, index_tensor.name, input.name],
        [out_tensor.name],
    )

    checks = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1439
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on ``input``.

    The integer and floating-point stride/offset parameters, ``shift``,
    ``mode`` and ``output_dims`` are stored in a ResizeAttribute.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        shift=shift,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=output_dims,
        offset=offset,
        offset_fp=offset_fp,
        stride=stride,
        stride_fp=stride_fp,
        input_list=in_names,
        output_list=out_names,
        result_tensor=result_tens,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return result_tens
1511
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator that passes ``val`` and ``val2`` through.

    ``validator_fcns`` is accepted to match the common builder signature but
    is not used; no ERROR_IF validation is performed here.

    Returns only the first result tensor (the one produced for ``val``).
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator-definition dict, as in every other builder; the
    # serializer needs the operator code stored under its "op" key, not the
    # whole dict.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1519
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Expose the tensor ``val`` as a graph output and return it unchanged.

    No operator is added; ``op``, ``validator_fcns`` and ``error_name`` are
    accepted only to match the common builder signature.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001523
1524 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting ``val`` to ``out_dtype``.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    validation_args = dict(
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
1557
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Random input/output zero points and random scales (one per channel of the
    last dimension when ``per_channel`` is set, otherwise one) are generated.
    The scales are converted to multiplier/shift pairs and stored, with the
    zero points, in a RescaleAttribute. For valid (non ERROR_IF) scale32
    tests the placeholder data file of ``val`` is clamped, and rewritten if
    anything changed, so that every value satisfies the apply_scale_32 range
    requirement.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of independent scales: last dimension when per-channel, else 1
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr stores int32; convert to a Python int so the bound is
        # computed with arbitrary precision. Shifting an int32 NumPy scalar by
        # 32 or more bits (possible for large shifts under NEP 50 promotion)
        # does not produce the intended 64-bit bound.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 2)
        max_shift_value_arr[i] = (1 << (shift - 2)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1712
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used (``else_tens`` is ignored); both
    blocks produce a random INT32 constant of that shape. For the
    output-list mismatch ERROR_IF tests, the selected block instead outputs a
    constant with every dimension changed.

    Returns the INT32 result tensor, or None when the ERROR_IF validation
    fails.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    out_shape = then_tens.shape

    # Build a mismatching output shape/constant for the error_if tests
    wrong_shape = None
    wrong_data = None
    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        wrong_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(wrong_shape):
            # Larger dims may shrink; small dims only grow so they stay >= 1
            if dim > 3:
                wrong_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                wrong_shape[idx] = dim + self.rng.choice([1, 2, 4])
        wrong_data = np.int32(self.rng.integers(0, 256, size=wrong_shape))

    then_data = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_data = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor stands for the output of either block
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block just outputs a constant; the error case swaps in wrong_shape
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_data),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_data),
    )
    for block_name, mismatch_error, block_data in block_specs:
        self.ser.startBasicBlock(block_name)
        if error_name == mismatch_error:
            const_tens = self.ser.addConst(wrong_shape, DType.INT32, wrong_data)
        else:
            const_tens = self.ser.addConst(out_shape, DType.INT32, block_data)
        self.ser.addOutputTensor(const_tens)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    if not checks_ok:
        return None

    return result_tens
1780
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a``/``b``.

    FLOAT/INT32 inputs use ADD in the THEN block and SUB in the ELSE block;
    INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT. For the
    input/output-list mismatch ERROR_IF tests, the selected block gets an
    input or output with every dimension of ``a``'s shape changed.

    Returns the result tensor, or None when the ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Change every dimension while keeping it >= 1 (same scheme as
        # build_cond_if_const), so the shape mismatches without becoming a
        # zero/negative, invalid shape.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable is deliberately NOT named ``op``: that would rebind
    # the operator-definition dict that is passed to the validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1862
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that accumulates ``a`` into a zero-initialised tensor.

    The loop carries (iter, a, acc). The COND block outputs ``iter > 0``; the
    BODY block outputs ``iter - 1``, ``a`` and ``a + acc``. ``iter`` starts
    at ``iter_val``. For ERROR_IF tests, the operator outputs or the block
    inputs/outputs are given mismatching shapes, or the COND output is given
    a non-BOOL type.

    Returns the accumulator output tensor, or None when the ERROR_IF
    validation fails.
    """

    def mismatched_shape(shape):
        # Return a copy of ``shape`` with every dimension changed but kept
        # >= 1 (mirrors build_cond_if_const), so the result mismatches without
        # becoming a zero/negative, invalid shape.
        new_shape = []
        for dim in shape:
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            new_shape.append(dim + delta)
        return new_shape

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised with zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out = self.ser.addIntermediate(mismatched_shape(acc.shape), acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        if incorrect_iter.shape:
            incorrect_iter.shape = mismatched_shape(incorrect_iter.shape)
        else:
            # iter is a scalar: mismatch it by adding a (positive) dimension
            incorrect_iter.shape = [self.rng.choice([2, 3])]

        incorrect_acc = deepcopy(acc)
        incorrect_acc.shape = mismatched_shape(acc.shape)

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_tens = self.ser.addOutput(
            [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
        )
    else:
        cond_tens = self.ser.addOutput([], DType.BOOL)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return acc_out
1974
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests.

    Args:
        op: operator definition dict; reads its "rank" (min, max) pair and
            its "types" list.
        shapeFilter: requested shapes; falsy means ``[None]`` (no filter).
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. A list-valued op dtype
            matches when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator or an unknown testType.
    """
    # Default ranks 1-4 inclusive keep test sizes reasonably small
    default_ranks = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    rank_lo, rank_hi = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks the operator supports
        ranks = [r for r in rankFilter if rank_lo <= r <= rank_hi]
    elif shapeFilter[0] is None:
        # Use the operator's range or the default one, whichever is smaller
        op_ranks = range(rank_lo, rank_hi + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(rank_lo, rank_hi + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dt
            for dt in op_dtypes
            if dt in dtypeFilter
            or (isinstance(dt, list) and dt[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # Negative tests: the validator's parameter requirements take precedence
    reqs = validator(check=False, op=op)["param_reqs"]
    neg_ranks = reqs["rank"] if reqs["rank"] is not None else ranks
    neg_dtypes = reqs["dtype"] if reqs["dtype"] is not None else dtypes
    # Without a required shape, use only the first two to keep test counts small
    neg_shapes = reqs["shape"] if reqs["shape"] is not None else shapeFilter[:2]
    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2055
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Return the list of tests to generate for one operator.

    For each ERROR_IF validator of the op (negative tests), or a single None
    placeholder (positive tests), the filters are cleaned by
    create_filter_lists. Every remaining rank/dtype/shape combination is then
    expanded through the op's tensor generator and argument generator.

    Args:
        opName: Key into TOSA_OP_LIST.
        shapeFilter: Shapes to restrict generation to; a None entry lets the
            tensor generator choose the shape. Defaults to [None].
        rankFilter: Forwarded to create_filter_lists.
        dtypeFilter: Forwarded to create_filter_lists.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples. Returns [] as soon as create_filter_lists returns None.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        ValueError: if testType is neither "positive" nor "negative" and a
            test name has to be built.
    """
    if shapeFilter is None:
        # Fresh list per call rather than a shared mutable default argument
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Re-seed so the tests generated for this op depend only on random_seed,
    # not on which ops were generated before it.
    self.rng = np.random.default_rng(self.random_seed)

    _build_fcn, tgen_fcn, _tvgen_fcn, agen_fcn = op["build_fcn"]

    # Each test is a tuple of:
    # (opName, testStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # NOTE(review): this also discards tests already collected for
            # earlier validators - confirm that is intended.
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list of (string representation, build function
                    # argument list) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            prefix = "{}".format(opName)
                        elif testType == "negative":
                            prefix = "{}_ERRORIF_{}".format(opName, error_name)
                        else:
                            raise ValueError(
                                "Unknown testType {!r}: expected 'positive' "
                                "or 'negative'".format(testType)
                            )

                        if argStr:
                            testStr = "{}_{}_{}_{}".format(
                                prefix, shapeStr, typeStr, argStr
                            )
                        else:
                            testStr = "{}_{}_{}".format(prefix, shapeStr, typeStr)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is removed if ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2162
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test's operator graph and serialize it.

    Looks up opName in TOSA_OP_LIST and creates a serializer for
    (opName, testStr). It expands the dtype argument into one dtype per
    operand and generates the input tensors with the op's TensorValuesGen
    function. It then calls the op's build function; validator_fcns and
    error_name keywords are added when the op defines error_if_validators.
    The test is serialized only if the build function returns a truthy
    result name; otherwise a message is printed.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        AssertionError: if the shape/dtype list lengths don't match the
            operand count (not checked for CONCAT).
        TypeError: re-raised from the build function after printing its
            inputs.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT has a variable number of inputs: one dtype per input shape
        copies = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)

    try:
        call_args = [*tens, *testArgs]
        call_kwargs = {}
        if error_if_validators is None:
            # Without ERROR_IF validators qinfo is passed positionally
            if qinfo is not None:
                call_args.append(qinfo)
        else:
            call_kwargs["validator_fcns"] = error_if_validators
            call_kwargs["error_name"] = error_name
            if qinfo is not None:
                call_kwargs["qinfo"] = qinfo
        resultName = build_fcn(self, op, *call_args, **call_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002253
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    For each kernel size, every "<prefix>_TEMPLATE" entry is shallow-copied
    into a "<prefix>_<k0>x<k1>[x<k2>]" entry. The copy's "filter" is set to
    that kernel and its "template" flag to False. Afterwards, every entry
    whose "template" field is truthy is deleted.

    TOSA_OP_LIST is reached through ``self`` but is defined on the class, so
    these changes are shared by every instance. A template that is no longer
    present (already expanded and removed by an earlier instance) is skipped
    instead of raising KeyError. This keeps repeated construction of
    TosaTestGen in one process working.
    """
    # Dynamically create op lists for convolutions with a list of kernel sizes
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # Kernels form the outer loop so new entries are inserted in the order
    # conv2d_1x1, depthwise_conv2d_1x1, transpose_conv2d_1x1, conv2d_2x2, ...
    expansions = (
        (KERNELS_2D, ("conv2d", "depthwise_conv2d", "transpose_conv2d")),
        (KERNELS_3D, ("conv3d",)),
    )
    for kernels, prefixes in expansions:
        for k in kernels:
            for prefix in prefixes:
                template = self.TOSA_OP_LIST.get(prefix + "_TEMPLATE")
                if template is None:
                    # Already expanded (and removed) by a previous call
                    continue
                testName = "{}_{}".format(prefix, "x".join(str(d) for d in k))
                entry = template.copy()
                entry["filter"] = k
                entry["template"] = False
                self.TOSA_OP_LIST[testName] = entry

    # Delete any templates after having created any dynamic ops. This is a
    # two-pass operation because keys must not be deleted from a dictionary
    # while iterating over it.
    keyList = [
        k for k, entry in self.TOSA_OP_LIST.items() if entry.get("template", False)
    ]
    for k in keyList:
        del self.TOSA_OP_LIST[k]
2300
def initOpListDefaults(self):
    """Validate each TOSA_OP_LIST entry and fill in optional defaults.

    Every entry must have an "operands" field unpacking to two values, a
    "build_fcn" field unpacking to four values, a "types" field and an "op"
    field. The first violation found raises an Exception naming the op.
    Entries without a "rank" get DEFAULT_RANK_RANGE, set in place.
    """
    for opName, entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        # Required: (build, tensor gen, tensor values gen, arg gen) functions
        try:
            _build, _tgen, _tvgen, _argGen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        # Required: simple presence checks, in this order
        required = (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        )
        for field, message in required:
            try:
                entry[field]
            except KeyError:
                raise Exception(message.format(opName))

        # Optional: default rank range
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002342
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': operator value from tosa.Op (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults sets DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#     TensorValuesGen function, ArgGen function or None)
# 'types': list of datatypes to be tested; an element may itself be a list
#     giving one datatype per operand (see TYPE_CONV)
# 'qgen': optional TosaQuantGen function producing the qinfo passed to the
#     build function
# 'invalid_test_validators': optional validators used to drop positive tests
#     that would fail without matching an ERROR_IF statement
# 'error_if_validators': optional validators used to create negative
#     (ERROR_IF) tests
# 'template': optional; True marks an entry that createDynamicOpLists expands
#     into concrete ops (with a 'filter' kernel size) and then removes
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# Each list element gives one datatype per operand (input, weights, and a
# third operand); DType.FLOAT alone is used for every operand
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002370
2371 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002372 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002373 "argmax": {
2374 "op": Op.ARGMAX,
2375 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002376 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002377 "build_fcn": (
2378 build_argmax,
2379 TosaTensorGen.tgBasic,
2380 TosaTensorValuesGen.tvgDefault,
2381 TosaArgGen.agAxis,
2382 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002383 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002384 "error_if_validators": (
2385 TosaErrorValidator.evAxisSmallerZero,
2386 TosaErrorValidator.evAxisLargerRank,
2387 TosaErrorValidator.evArgmaxOutputRankMismatch,
2388 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2389 TosaErrorValidator.evWrongRank,
2390 TosaErrorValidator.evWrongInputType,
2391 TosaErrorValidator.evWrongOutputType,
2392 TosaErrorValidator.evWrongInputList,
2393 TosaErrorValidator.evWrongOutputList,
2394 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002395 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002396 "avg_pool2d": {
2397 "op": Op.AVG_POOL2D,
2398 "operands": (1, 0),
2399 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002400 "build_fcn": (
2401 build_pool2d,
2402 TosaTensorGen.tgNHWC,
2403 TosaTensorValuesGen.tvgDefault,
2404 TosaArgGen.agPooling,
2405 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002406 "qgen": TosaQuantGen.qgUnary,
2407 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002408 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002409 "error_if_validators": (
2410 TosaErrorValidator.evKernelSmallerOne,
2411 TosaErrorValidator.evStrideSmallerOne,
2412 TosaErrorValidator.evPadSmallerZero,
2413 TosaErrorValidator.evWrongRank,
2414 TosaErrorValidator.evWrongInputType,
2415 TosaErrorValidator.evWrongOutputType,
2416 TosaErrorValidator.evWrongInputList,
2417 TosaErrorValidator.evWrongOutputList,
2418 TosaErrorValidator.evInputZeroPointNotZero,
2419 TosaErrorValidator.evOutputZeroPointNotZero,
2420 TosaErrorValidator.evPadLargerEqualKernel,
2421 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002422 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002424 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002425 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002426 "conv2d_TEMPLATE": {
2427 "op": Op.CONV2D,
2428 "operands": (1, 2),
2429 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002430 "build_fcn": (
2431 build_conv2d,
2432 TosaTensorGen.tgConv2D,
2433 TosaTensorValuesGen.tvgDefault,
2434 TosaArgGen.agConv,
2435 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002436 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002437 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002438 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2439 "error_if_validators": (
2440 TosaErrorValidator.evWrongInputType,
2441 TosaErrorValidator.evWrongOutputType,
2442 TosaErrorValidator.evWrongInputList,
2443 TosaErrorValidator.evWrongOutputList,
2444 TosaErrorValidator.evInputZeroPointNotZero,
2445 TosaErrorValidator.evWeightZeroPointNotZero,
2446 TosaErrorValidator.evPadSmallerZero,
2447 TosaErrorValidator.evStrideSmallerOne,
2448 TosaErrorValidator.evDilationSmallerOne,
2449 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002450 TosaErrorValidator.evConvOutputShapeMismatch,
2451 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002452 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002453 "template": True,
2454 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002455 # Templated operator. Filled in by createDynamicOpLists
2456 "conv3d_TEMPLATE": {
2457 "op": Op.CONV3D,
2458 "operands": (1, 2),
2459 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002460 "build_fcn": (
2461 build_conv3d,
2462 TosaTensorGen.tgConv3D,
2463 TosaTensorValuesGen.tvgDefault,
2464 TosaArgGen.agConv,
2465 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002466 "qgen": TosaQuantGen.qgConv,
2467 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002468 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2469 "error_if_validators": (
2470 TosaErrorValidator.evWrongInputType,
2471 TosaErrorValidator.evWrongOutputType,
2472 TosaErrorValidator.evWrongInputList,
2473 TosaErrorValidator.evWrongOutputList,
2474 TosaErrorValidator.evInputZeroPointNotZero,
2475 TosaErrorValidator.evWeightZeroPointNotZero,
2476 TosaErrorValidator.evPadSmallerZero,
2477 TosaErrorValidator.evStrideSmallerOne,
2478 TosaErrorValidator.evDilationSmallerOne,
2479 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002480 TosaErrorValidator.evConvOutputShapeMismatch,
2481 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002482 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002483 "template": True,
2484 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002485 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002486 "depthwise_conv2d_TEMPLATE": {
2487 "op": Op.DEPTHWISE_CONV2D,
2488 "operands": (1, 2),
2489 "filter": [1, 1],
2490 "rank": (4, 4),
2491 "build_fcn": (
2492 build_depthwise_conv2d,
2493 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002494 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002495 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002496 ),
2497 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002498 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002499 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2500 "error_if_validators": (
2501 TosaErrorValidator.evWrongInputType,
2502 TosaErrorValidator.evWrongOutputType,
2503 TosaErrorValidator.evWrongInputList,
2504 TosaErrorValidator.evWrongOutputList,
2505 TosaErrorValidator.evInputZeroPointNotZero,
2506 TosaErrorValidator.evWeightZeroPointNotZero,
2507 TosaErrorValidator.evPadSmallerZero,
2508 TosaErrorValidator.evStrideSmallerOne,
2509 TosaErrorValidator.evDilationSmallerOne,
2510 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002511 TosaErrorValidator.evConvOutputShapeMismatch,
2512 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002513 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002514 "template": True,
2515 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002516 "fully_connected": {
2517 "op": Op.FULLY_CONNECTED,
2518 "operands": (1, 2),
2519 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002520 "build_fcn": (
2521 build_fully_connected,
2522 TosaTensorGen.tgFullyConnected,
2523 TosaTensorValuesGen.tvgDefault,
2524 None,
2525 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002526 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002527 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002528 "error_if_validators": (
2529 TosaErrorValidator.evInputZeroPointNotZero,
2530 TosaErrorValidator.evWeightZeroPointNotZero,
2531 TosaErrorValidator.evWrongRank,
2532 TosaErrorValidator.evWrongInputType,
2533 TosaErrorValidator.evWrongOutputType,
2534 TosaErrorValidator.evWrongInputList,
2535 TosaErrorValidator.evWrongOutputList,
2536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002538 "matmul": {
2539 "op": Op.MATMUL,
2540 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002541 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 "build_fcn": (
2543 build_matmul,
2544 TosaTensorGen.tgMatmul,
2545 TosaTensorValuesGen.tvgDefault,
2546 None,
2547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002548 "qgen": TosaQuantGen.qgMatmul,
2549 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002550 "error_if_validators": (
2551 TosaErrorValidator.evInputZeroPointNotZero,
2552 TosaErrorValidator.evWrongRank,
2553 TosaErrorValidator.evWrongInputType,
2554 TosaErrorValidator.evWrongOutputType,
2555 TosaErrorValidator.evWrongInputList,
2556 TosaErrorValidator.evWrongOutputList,
2557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002558 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002559 "max_pool2d": {
2560 "op": Op.MAX_POOL2D,
2561 "operands": (1, 0),
2562 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002563 "build_fcn": (
2564 build_pool2d,
2565 TosaTensorGen.tgNHWC,
2566 TosaTensorValuesGen.tvgDefault,
2567 TosaArgGen.agPooling,
2568 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002569 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002570 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002571 "error_if_validators": (
2572 TosaErrorValidator.evKernelSmallerOne,
2573 TosaErrorValidator.evStrideSmallerOne,
2574 TosaErrorValidator.evPadSmallerZero,
2575 TosaErrorValidator.evWrongRank,
2576 TosaErrorValidator.evWrongInputType,
2577 TosaErrorValidator.evWrongOutputType,
2578 TosaErrorValidator.evWrongInputList,
2579 TosaErrorValidator.evWrongOutputList,
2580 TosaErrorValidator.evPadLargerEqualKernel,
2581 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002582 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002583 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002584 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002585 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002586 "transpose_conv2d_TEMPLATE": {
2587 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002588 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002589 "rank": (4, 4),
2590 "build_fcn": (
2591 build_transpose_conv2d,
2592 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002593 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002594 TosaArgGen.agTransposeConv2D,
2595 ),
2596 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002597 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002598 "invalid_test_validators": (
2599 TosaInvalidValidator.ivHeightWidthInvalid,
2600 TosaInvalidValidator.ivNonPositiveOutputShape,
2601 ),
2602 "error_if_validators": (
2603 TosaErrorValidator.evWrongInputType,
2604 TosaErrorValidator.evWrongOutputType,
2605 TosaErrorValidator.evWrongInputList,
2606 TosaErrorValidator.evWrongOutputList,
2607 TosaErrorValidator.evInputZeroPointNotZero,
2608 TosaErrorValidator.evWeightZeroPointNotZero,
2609 TosaErrorValidator.evPadSmallerZero,
2610 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002611 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002612 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002613 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002614 "template": True,
2615 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002616 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002617 "clamp": {
2618 "op": Op.CLAMP,
2619 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002620 "build_fcn": (
2621 build_clamp,
2622 TosaTensorGen.tgBasic,
2623 TosaTensorValuesGen.tvgDefault,
2624 None,
2625 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002626 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002627 "error_if_validators": (
2628 TosaErrorValidator.evMaxSmallerMin,
2629 TosaErrorValidator.evWrongInputType,
2630 TosaErrorValidator.evWrongOutputType,
2631 TosaErrorValidator.evWrongInputList,
2632 TosaErrorValidator.evWrongOutputList,
2633 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002634 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002635 "sigmoid": {
2636 "op": Op.SIGMOID,
2637 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002638 "build_fcn": (
2639 build_sigmoid,
2640 TosaTensorGen.tgBasic,
2641 TosaTensorValuesGen.tvgDefault,
2642 None,
2643 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002644 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002645 "error_if_validators": (
2646 TosaErrorValidator.evWrongInputType,
2647 TosaErrorValidator.evWrongOutputType,
2648 TosaErrorValidator.evWrongInputList,
2649 TosaErrorValidator.evWrongOutputList,
2650 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002651 },
2652 "tanh": {
2653 "op": Op.TANH,
2654 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002655 "build_fcn": (
2656 build_tanh,
2657 TosaTensorGen.tgBasic,
2658 TosaTensorValuesGen.tvgDefault,
2659 None,
2660 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002661 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002662 "error_if_validators": (
2663 TosaErrorValidator.evWrongInputType,
2664 TosaErrorValidator.evWrongOutputType,
2665 TosaErrorValidator.evWrongInputList,
2666 TosaErrorValidator.evWrongOutputList,
2667 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002668 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002669 # Elementwise Binary Operators
2670 "add": {
2671 "op": Op.ADD,
2672 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002673 "build_fcn": (
2674 build_binary_broadcast,
2675 TosaTensorGen.tgBroadcastFuzz,
2676 TosaTensorValuesGen.tvgAddSub,
2677 None,
2678 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002679 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002680 "error_if_validators": (
2681 TosaErrorValidator.evRankMismatch,
2682 TosaErrorValidator.evWrongInputType,
2683 TosaErrorValidator.evWrongOutputType,
2684 TosaErrorValidator.evWrongInputList,
2685 TosaErrorValidator.evWrongOutputList,
2686 TosaErrorValidator.evDimensionMismatch,
2687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002689 "arithmetic_right_shift": {
2690 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2691 "operands": (2, 0),
2692 "build_fcn": (
2693 build_arithmetic_right_shift,
2694 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002695 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002696 TosaArgGen.agArithmeticRightShift,
2697 ),
2698 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002699 "error_if_validators": (
2700 TosaErrorValidator.evRankMismatch,
2701 TosaErrorValidator.evWrongInputType,
2702 TosaErrorValidator.evWrongOutputType,
2703 TosaErrorValidator.evWrongInputList,
2704 TosaErrorValidator.evWrongOutputList,
2705 TosaErrorValidator.evDimensionMismatch,
2706 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002707 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002708 "bitwise_and": {
2709 "op": Op.BITWISE_AND,
2710 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002711 "build_fcn": (
2712 build_binary_broadcast,
2713 TosaTensorGen.tgBroadcastFuzz,
2714 TosaTensorValuesGen.tvgDefault,
2715 None,
2716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002717 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002718 "error_if_validators": (
2719 TosaErrorValidator.evRankMismatch,
2720 TosaErrorValidator.evWrongInputType,
2721 TosaErrorValidator.evWrongOutputType,
2722 TosaErrorValidator.evWrongInputList,
2723 TosaErrorValidator.evWrongOutputList,
2724 TosaErrorValidator.evDimensionMismatch,
2725 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002726 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002727 "bitwise_or": {
2728 "op": Op.BITWISE_OR,
2729 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002730 "build_fcn": (
2731 build_binary_broadcast,
2732 TosaTensorGen.tgBroadcastFuzz,
2733 TosaTensorValuesGen.tvgDefault,
2734 None,
2735 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002736 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002737 "error_if_validators": (
2738 TosaErrorValidator.evRankMismatch,
2739 TosaErrorValidator.evWrongInputType,
2740 TosaErrorValidator.evWrongOutputType,
2741 TosaErrorValidator.evWrongInputList,
2742 TosaErrorValidator.evWrongOutputList,
2743 TosaErrorValidator.evDimensionMismatch,
2744 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002746 "bitwise_xor": {
2747 "op": Op.BITWISE_XOR,
2748 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002749 "build_fcn": (
2750 build_binary_broadcast,
2751 TosaTensorGen.tgBroadcastFuzz,
2752 TosaTensorValuesGen.tvgDefault,
2753 None,
2754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002755 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002756 "error_if_validators": (
2757 TosaErrorValidator.evRankMismatch,
2758 TosaErrorValidator.evWrongInputType,
2759 TosaErrorValidator.evWrongOutputType,
2760 TosaErrorValidator.evWrongInputList,
2761 TosaErrorValidator.evWrongOutputList,
2762 TosaErrorValidator.evDimensionMismatch,
2763 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002764 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002765 "intdiv": {
2766 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002767 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002768 "build_fcn": (
2769 build_binary_broadcast,
2770 TosaTensorGen.tgBroadcastFuzz,
2771 TosaTensorValuesGen.tvgIntDiv,
2772 None,
2773 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002774 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002775 "error_if_validators": (
2776 TosaErrorValidator.evRankMismatch,
2777 TosaErrorValidator.evWrongInputType,
2778 TosaErrorValidator.evWrongOutputType,
2779 TosaErrorValidator.evWrongInputList,
2780 TosaErrorValidator.evWrongOutputList,
2781 TosaErrorValidator.evDimensionMismatch,
2782 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002783 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002784 "logical_and": {
2785 "op": Op.LOGICAL_AND,
2786 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002787 "build_fcn": (
2788 build_binary_broadcast,
2789 TosaTensorGen.tgBroadcastFuzz,
2790 TosaTensorValuesGen.tvgDefault,
2791 None,
2792 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002793 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002794 "error_if_validators": (
2795 TosaErrorValidator.evRankMismatch,
2796 TosaErrorValidator.evWrongInputType,
2797 TosaErrorValidator.evWrongOutputType,
2798 TosaErrorValidator.evWrongInputList,
2799 TosaErrorValidator.evWrongOutputList,
2800 TosaErrorValidator.evDimensionMismatch,
2801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002802 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002803 "logical_left_shift": {
2804 "op": Op.LOGICAL_LEFT_SHIFT,
2805 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002806 "build_fcn": (
2807 build_binary_broadcast,
2808 TosaTensorGen.tgBroadcastFuzz,
2809 TosaTensorValuesGen.tvgLogicalShift,
2810 None,
2811 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002812 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002813 "error_if_validators": (
2814 TosaErrorValidator.evRankMismatch,
2815 TosaErrorValidator.evWrongInputType,
2816 TosaErrorValidator.evWrongOutputType,
2817 TosaErrorValidator.evWrongInputList,
2818 TosaErrorValidator.evWrongOutputList,
2819 TosaErrorValidator.evDimensionMismatch,
2820 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002821 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002822 "logical_right_shift": {
2823 "op": Op.LOGICAL_RIGHT_SHIFT,
2824 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002825 "build_fcn": (
2826 build_binary_broadcast,
2827 TosaTensorGen.tgBroadcastFuzz,
2828 TosaTensorValuesGen.tvgLogicalShift,
2829 None,
2830 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002831 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002832 "error_if_validators": (
2833 TosaErrorValidator.evRankMismatch,
2834 TosaErrorValidator.evWrongInputType,
2835 TosaErrorValidator.evWrongOutputType,
2836 TosaErrorValidator.evWrongInputList,
2837 TosaErrorValidator.evWrongOutputList,
2838 TosaErrorValidator.evDimensionMismatch,
2839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002841 "logical_or": {
2842 "op": Op.LOGICAL_OR,
2843 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002844 "build_fcn": (
2845 build_binary_broadcast,
2846 TosaTensorGen.tgBroadcastFuzz,
2847 TosaTensorValuesGen.tvgDefault,
2848 None,
2849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002850 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002851 "error_if_validators": (
2852 TosaErrorValidator.evRankMismatch,
2853 TosaErrorValidator.evWrongInputType,
2854 TosaErrorValidator.evWrongOutputType,
2855 TosaErrorValidator.evWrongInputList,
2856 TosaErrorValidator.evWrongOutputList,
2857 TosaErrorValidator.evDimensionMismatch,
2858 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002859 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002860 "logical_xor": {
2861 "op": Op.LOGICAL_XOR,
2862 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002863 "build_fcn": (
2864 build_binary_broadcast,
2865 TosaTensorGen.tgBroadcastFuzz,
2866 TosaTensorValuesGen.tvgDefault,
2867 None,
2868 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002869 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002870 "error_if_validators": (
2871 TosaErrorValidator.evRankMismatch,
2872 TosaErrorValidator.evWrongInputType,
2873 TosaErrorValidator.evWrongOutputType,
2874 TosaErrorValidator.evWrongInputList,
2875 TosaErrorValidator.evWrongOutputList,
2876 TosaErrorValidator.evDimensionMismatch,
2877 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002878 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002879 "maximum": {
2880 "op": Op.MAXIMUM,
2881 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002882 "build_fcn": (
2883 build_binary_broadcast,
2884 TosaTensorGen.tgBroadcastFuzz,
2885 TosaTensorValuesGen.tvgDefault,
2886 None,
2887 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002888 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002889 "error_if_validators": (
2890 TosaErrorValidator.evRankMismatch,
2891 TosaErrorValidator.evWrongInputType,
2892 TosaErrorValidator.evWrongOutputType,
2893 TosaErrorValidator.evWrongInputList,
2894 TosaErrorValidator.evWrongOutputList,
2895 TosaErrorValidator.evDimensionMismatch,
2896 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002898 "minimum": {
2899 "op": Op.MINIMUM,
2900 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002901 "build_fcn": (
2902 build_binary_broadcast,
2903 TosaTensorGen.tgBroadcastFuzz,
2904 TosaTensorValuesGen.tvgDefault,
2905 None,
2906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002907 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002908 "error_if_validators": (
2909 TosaErrorValidator.evRankMismatch,
2910 TosaErrorValidator.evWrongInputType,
2911 TosaErrorValidator.evWrongOutputType,
2912 TosaErrorValidator.evWrongInputList,
2913 TosaErrorValidator.evWrongOutputList,
2914 TosaErrorValidator.evDimensionMismatch,
2915 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002916 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 "mul": {
2918 "op": Op.MUL,
2919 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002920 "build_fcn": (
2921 build_mul,
2922 TosaTensorGen.tgBroadcastFuzz,
2923 TosaTensorValuesGen.tvgMul,
2924 TosaArgGen.agMul,
2925 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002926 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002927 "error_if_validators": (
2928 TosaErrorValidator.evWrongInputType,
2929 TosaErrorValidator.evWrongOutputType,
2930 TosaErrorValidator.evWrongInputList,
2931 TosaErrorValidator.evWrongOutputList,
2932 TosaErrorValidator.evRankMismatch,
2933 TosaErrorValidator.evDimensionMismatch,
2934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002935 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002936 "pow": {
2937 "op": Op.POW,
2938 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 "build_fcn": (
2940 build_binary_broadcast,
2941 TosaTensorGen.tgBroadcastFuzz,
2942 TosaTensorValuesGen.tvgDefault,
2943 None,
2944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002945 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002946 "error_if_validators": (
2947 TosaErrorValidator.evRankMismatch,
2948 TosaErrorValidator.evWrongInputType,
2949 TosaErrorValidator.evWrongOutputType,
2950 TosaErrorValidator.evWrongInputList,
2951 TosaErrorValidator.evWrongOutputList,
2952 TosaErrorValidator.evDimensionMismatch,
2953 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002954 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002955 "sub": {
2956 "op": Op.SUB,
2957 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002958 "build_fcn": (
2959 build_binary_broadcast,
2960 TosaTensorGen.tgBroadcastFuzz,
2961 TosaTensorValuesGen.tvgAddSub,
2962 None,
2963 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002964 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002965 "error_if_validators": (
2966 TosaErrorValidator.evRankMismatch,
2967 TosaErrorValidator.evWrongInputType,
2968 TosaErrorValidator.evWrongOutputType,
2969 TosaErrorValidator.evWrongInputList,
2970 TosaErrorValidator.evWrongOutputList,
2971 TosaErrorValidator.evDimensionMismatch,
2972 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002973 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 "table": {
2975 "op": Op.TABLE,
2976 # Use the automatic generation functions to create the input array
2977 # but create the table tensor in the build function, as it may be
2978 # a different type from the input
2979 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002980 "build_fcn": (
2981 build_table,
2982 TosaTensorGen.tgBasic,
2983 TosaTensorValuesGen.tvgDefault,
2984 TosaArgGen.agTable,
2985 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002986 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002987 "error_if_validators": (
2988 TosaErrorValidator.evWrongInputType,
2989 TosaErrorValidator.evWrongOutputType,
2990 TosaErrorValidator.evWrongInputList,
2991 TosaErrorValidator.evWrongOutputList,
2992 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002993 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002994 # Elementwise Unary operators
2995 "abs": {
2996 "op": Op.ABS,
2997 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002998 "build_fcn": (
2999 build_unary,
3000 TosaTensorGen.tgBasic,
3001 TosaTensorValuesGen.tvgDefault,
3002 None,
3003 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003004 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003005 "error_if_validators": (
3006 TosaErrorValidator.evWrongInputType,
3007 TosaErrorValidator.evWrongOutputType,
3008 TosaErrorValidator.evWrongInputList,
3009 TosaErrorValidator.evWrongOutputList,
3010 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003011 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 "bitwise_not": {
3013 "op": Op.BITWISE_NOT,
3014 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003015 "build_fcn": (
3016 build_unary,
3017 TosaTensorGen.tgBasic,
3018 TosaTensorValuesGen.tvgDefault,
3019 None,
3020 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003021 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003022 "error_if_validators": (
3023 TosaErrorValidator.evWrongInputType,
3024 TosaErrorValidator.evWrongOutputType,
3025 TosaErrorValidator.evWrongInputList,
3026 TosaErrorValidator.evWrongOutputList,
3027 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003028 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 "ceil": {
3030 "op": Op.CEIL,
3031 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003032 "build_fcn": (
3033 build_unary,
3034 TosaTensorGen.tgBasic,
3035 TosaTensorValuesGen.tvgDefault,
3036 None,
3037 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003038 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003039 "error_if_validators": (
3040 TosaErrorValidator.evWrongInputType,
3041 TosaErrorValidator.evWrongOutputType,
3042 TosaErrorValidator.evWrongInputList,
3043 TosaErrorValidator.evWrongOutputList,
3044 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003045 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003046 "clz": {
3047 "op": Op.CLZ,
3048 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003049 "build_fcn": (
3050 build_unary,
3051 TosaTensorGen.tgBasic,
3052 TosaTensorValuesGen.tvgDefault,
3053 None,
3054 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003055 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003056 "error_if_validators": (
3057 TosaErrorValidator.evWrongInputType,
3058 TosaErrorValidator.evWrongOutputType,
3059 TosaErrorValidator.evWrongInputList,
3060 TosaErrorValidator.evWrongOutputList,
3061 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003063 "exp": {
3064 "op": Op.EXP,
3065 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 "build_fcn": (
3067 build_unary,
3068 TosaTensorGen.tgBasic,
3069 TosaTensorValuesGen.tvgDefault,
3070 None,
3071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003072 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003073 "error_if_validators": (
3074 TosaErrorValidator.evWrongInputType,
3075 TosaErrorValidator.evWrongOutputType,
3076 TosaErrorValidator.evWrongInputList,
3077 TosaErrorValidator.evWrongOutputList,
3078 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 "floor": {
3081 "op": Op.FLOOR,
3082 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003083 "build_fcn": (
3084 build_unary,
3085 TosaTensorGen.tgBasic,
3086 TosaTensorValuesGen.tvgDefault,
3087 None,
3088 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003089 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003090 "error_if_validators": (
3091 TosaErrorValidator.evWrongInputType,
3092 TosaErrorValidator.evWrongOutputType,
3093 TosaErrorValidator.evWrongInputList,
3094 TosaErrorValidator.evWrongOutputList,
3095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003097 "log": {
3098 "op": Op.LOG,
3099 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003100 "build_fcn": (
3101 build_unary,
3102 TosaTensorGen.tgBasic,
3103 TosaTensorValuesGen.tvgDefault,
3104 None,
3105 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003106 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003107 "error_if_validators": (
3108 TosaErrorValidator.evWrongInputType,
3109 TosaErrorValidator.evWrongOutputType,
3110 TosaErrorValidator.evWrongInputList,
3111 TosaErrorValidator.evWrongOutputList,
3112 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003113 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003114 "logical_not": {
3115 "op": Op.LOGICAL_NOT,
3116 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003117 "build_fcn": (
3118 build_unary,
3119 TosaTensorGen.tgBasic,
3120 TosaTensorValuesGen.tvgDefault,
3121 None,
3122 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003123 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003124 "error_if_validators": (
3125 TosaErrorValidator.evWrongInputType,
3126 TosaErrorValidator.evWrongOutputType,
3127 TosaErrorValidator.evWrongInputList,
3128 TosaErrorValidator.evWrongOutputList,
3129 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003130 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 "negate": {
3132 "op": Op.NEGATE,
3133 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003134 "build_fcn": (
3135 build_unary,
3136 TosaTensorGen.tgBasic,
3137 TosaTensorValuesGen.tvgNegate,
3138 None,
3139 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003140 "qgen": TosaQuantGen.qgUnary,
3141 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003142 "error_if_validators": (
3143 TosaErrorValidator.evInputZeroPointNotZero,
3144 TosaErrorValidator.evOutputZeroPointNotZero,
3145 TosaErrorValidator.evWrongInputType,
3146 TosaErrorValidator.evWrongOutputType,
3147 TosaErrorValidator.evWrongInputList,
3148 TosaErrorValidator.evWrongOutputList,
3149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "reciprocal": {
3152 "op": Op.RECIPROCAL,
3153 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003154 "build_fcn": (
3155 build_unary,
3156 TosaTensorGen.tgBasic,
3157 TosaTensorValuesGen.tvgDefault,
3158 None,
3159 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003160 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003161 "error_if_validators": (
3162 TosaErrorValidator.evWrongInputType,
3163 TosaErrorValidator.evWrongOutputType,
3164 TosaErrorValidator.evWrongInputList,
3165 TosaErrorValidator.evWrongOutputList,
3166 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003167 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 "rsqrt": {
3169 "op": Op.RSQRT,
3170 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003171 "build_fcn": (
3172 build_unary,
3173 TosaTensorGen.tgBasic,
3174 TosaTensorValuesGen.tvgDefault,
3175 None,
3176 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003178 "error_if_validators": (
3179 TosaErrorValidator.evWrongInputType,
3180 TosaErrorValidator.evWrongOutputType,
3181 TosaErrorValidator.evWrongInputList,
3182 TosaErrorValidator.evWrongOutputList,
3183 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003184 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 # Elementwise Ternary operators
3186 "select": {
3187 "op": Op.SELECT,
3188 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003189 "build_fcn": (
3190 build_select,
3191 TosaTensorGen.tgBroadcastFuzz,
3192 TosaTensorValuesGen.tvgSelect,
3193 None,
3194 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003195 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003196 "error_if_validators": (
3197 TosaErrorValidator.evRankMismatch,
3198 TosaErrorValidator.evWrongInputType,
3199 TosaErrorValidator.evWrongOutputType,
3200 TosaErrorValidator.evWrongInputList,
3201 TosaErrorValidator.evWrongOutputList,
3202 TosaErrorValidator.evDimensionMismatch,
3203 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003204 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003205 # Comparison operators
3206 "equal": {
3207 "op": Op.EQUAL,
3208 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003209 "build_fcn": (
3210 build_comparison,
3211 TosaTensorGen.tgBroadcastFuzz,
3212 TosaTensorValuesGen.tvgEqual,
3213 None,
3214 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003215 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003216 "error_if_validators": (
3217 TosaErrorValidator.evRankMismatch,
3218 TosaErrorValidator.evWrongInputType,
3219 TosaErrorValidator.evWrongOutputType,
3220 TosaErrorValidator.evWrongInputList,
3221 TosaErrorValidator.evWrongOutputList,
3222 TosaErrorValidator.evDimensionMismatch,
3223 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003224 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003225 "greater_equal": {
3226 "op": Op.GREATER_EQUAL,
3227 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003228 "build_fcn": (
3229 build_comparison,
3230 TosaTensorGen.tgBroadcastFuzz,
3231 TosaTensorValuesGen.tvgDefault,
3232 None,
3233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003235 "error_if_validators": (
3236 TosaErrorValidator.evRankMismatch,
3237 TosaErrorValidator.evWrongInputType,
3238 TosaErrorValidator.evWrongOutputType,
3239 TosaErrorValidator.evWrongInputList,
3240 TosaErrorValidator.evWrongOutputList,
3241 TosaErrorValidator.evDimensionMismatch,
3242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003244 "greater": {
3245 "op": Op.GREATER,
3246 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 "build_fcn": (
3248 build_comparison,
3249 TosaTensorGen.tgBroadcastFuzz,
3250 TosaTensorValuesGen.tvgDefault,
3251 None,
3252 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003254 "error_if_validators": (
3255 TosaErrorValidator.evRankMismatch,
3256 TosaErrorValidator.evWrongInputType,
3257 TosaErrorValidator.evWrongOutputType,
3258 TosaErrorValidator.evWrongInputList,
3259 TosaErrorValidator.evWrongOutputList,
3260 TosaErrorValidator.evDimensionMismatch,
3261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 # Reduction operators
3264 "reduce_all": {
3265 "op": Op.REDUCE_ALL,
3266 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003267 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 "build_fcn": (
3269 build_reduce,
3270 TosaTensorGen.tgBasic,
3271 TosaTensorValuesGen.tvgDefault,
3272 TosaArgGen.agAxis,
3273 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003274 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003275 "error_if_validators": (
3276 TosaErrorValidator.evAxisLargerRank,
3277 TosaErrorValidator.evAxisSmallerZero,
3278 TosaErrorValidator.evShapeOfAxisNotOne,
3279 TosaErrorValidator.evWrongInputType,
3280 TosaErrorValidator.evWrongOutputType,
3281 TosaErrorValidator.evWrongRank,
3282 TosaErrorValidator.evWrongInputList,
3283 TosaErrorValidator.evWrongOutputList,
3284 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 "reduce_any": {
3287 "op": Op.REDUCE_ANY,
3288 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003289 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003290 "build_fcn": (
3291 build_reduce,
3292 TosaTensorGen.tgBasic,
3293 TosaTensorValuesGen.tvgDefault,
3294 TosaArgGen.agAxis,
3295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003296 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003297 "error_if_validators": (
3298 TosaErrorValidator.evAxisLargerRank,
3299 TosaErrorValidator.evAxisSmallerZero,
3300 TosaErrorValidator.evShapeOfAxisNotOne,
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongRank,
3304 TosaErrorValidator.evWrongInputList,
3305 TosaErrorValidator.evWrongOutputList,
3306 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003307 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 "reduce_max": {
3309 "op": Op.REDUCE_MAX,
3310 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003311 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003312 "build_fcn": (
3313 build_reduce,
3314 TosaTensorGen.tgBasic,
3315 TosaTensorValuesGen.tvgDefault,
3316 TosaArgGen.agAxis,
3317 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003318 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003319 "error_if_validators": (
3320 TosaErrorValidator.evAxisLargerRank,
3321 TosaErrorValidator.evAxisSmallerZero,
3322 TosaErrorValidator.evShapeOfAxisNotOne,
3323 TosaErrorValidator.evWrongInputType,
3324 TosaErrorValidator.evWrongOutputType,
3325 TosaErrorValidator.evWrongRank,
3326 TosaErrorValidator.evWrongInputList,
3327 TosaErrorValidator.evWrongOutputList,
3328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003331 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003333 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003334 "build_fcn": (
3335 build_reduce,
3336 TosaTensorGen.tgBasic,
3337 TosaTensorValuesGen.tvgDefault,
3338 TosaArgGen.agAxis,
3339 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003340 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003341 "error_if_validators": (
3342 TosaErrorValidator.evAxisLargerRank,
3343 TosaErrorValidator.evAxisSmallerZero,
3344 TosaErrorValidator.evShapeOfAxisNotOne,
3345 TosaErrorValidator.evWrongInputType,
3346 TosaErrorValidator.evWrongOutputType,
3347 TosaErrorValidator.evWrongRank,
3348 TosaErrorValidator.evWrongInputList,
3349 TosaErrorValidator.evWrongOutputList,
3350 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003351 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 "reduce_product": {
3353 "op": Op.REDUCE_PRODUCT,
3354 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003355 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003356 "build_fcn": (
3357 build_reduce,
3358 TosaTensorGen.tgBasic,
3359 TosaTensorValuesGen.tvgDefault,
3360 TosaArgGen.agAxis,
3361 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003362 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003363 "error_if_validators": (
3364 TosaErrorValidator.evAxisLargerRank,
3365 TosaErrorValidator.evAxisSmallerZero,
3366 TosaErrorValidator.evShapeOfAxisNotOne,
3367 TosaErrorValidator.evWrongInputType,
3368 TosaErrorValidator.evWrongOutputType,
3369 TosaErrorValidator.evWrongRank,
3370 TosaErrorValidator.evWrongInputList,
3371 TosaErrorValidator.evWrongOutputList,
3372 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 "reduce_sum": {
3375 "op": Op.REDUCE_SUM,
3376 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003377 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003378 "build_fcn": (
3379 build_reduce,
3380 TosaTensorGen.tgBasic,
3381 TosaTensorValuesGen.tvgReduceSum,
3382 TosaArgGen.agAxis,
3383 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003385 "error_if_validators": (
3386 TosaErrorValidator.evAxisLargerRank,
3387 TosaErrorValidator.evAxisSmallerZero,
3388 TosaErrorValidator.evShapeOfAxisNotOne,
3389 TosaErrorValidator.evWrongInputType,
3390 TosaErrorValidator.evWrongOutputType,
3391 TosaErrorValidator.evWrongRank,
3392 TosaErrorValidator.evWrongInputList,
3393 TosaErrorValidator.evWrongOutputList,
3394 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003395 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003396 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003397 "concat": {
3398 "op": Op.CONCAT,
3399 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 "build_fcn": (
3401 build_concat,
3402 TosaTensorGen.tgConcat,
3403 TosaTensorValuesGen.tvgConcat,
3404 TosaArgGen.agAxis,
3405 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003406 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003407 "error_if_validators": (
3408 TosaErrorValidator.evAxisLargerRank,
3409 TosaErrorValidator.evAxisSmallerZero,
3410 TosaErrorValidator.evConcatInputRankMismatch,
3411 TosaErrorValidator.evConcatShapeSumMismatch,
3412 TosaErrorValidator.evConcatInputDimMismatch,
3413 TosaErrorValidator.evWrongInputType,
3414 TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongOutputList,
3416 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003417 },
3418 "pad": {
3419 "op": Op.PAD,
3420 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003421 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 "build_fcn": (
3423 build_pad,
3424 TosaTensorGen.tgBasic,
3425 TosaTensorValuesGen.tvgDefault,
3426 TosaArgGen.agPad,
3427 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003428 "qgen": TosaQuantGen.qgPad,
3429 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003430 "error_if_validators": (
3431 TosaErrorValidator.evWrongInputType,
3432 TosaErrorValidator.evPadSmallerZero,
3433 TosaErrorValidator.evWrongOutputType,
3434 TosaErrorValidator.evWrongInputList,
3435 TosaErrorValidator.evWrongOutputList,
3436 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003437 },
3438 "reshape": {
3439 "op": Op.RESHAPE,
3440 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003441 "build_fcn": (
3442 build_reshape,
3443 TosaTensorGen.tgBasic,
3444 TosaTensorValuesGen.tvgDefault,
3445 TosaArgGen.agReshape,
3446 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003447 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003448 "error_if_validators": (
3449 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongInputList,
3453 TosaErrorValidator.evWrongOutputList,
3454 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003455 },
3456 "reverse": {
3457 "op": Op.REVERSE,
3458 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
3460 build_reverse,
3461 TosaTensorGen.tgBasic,
3462 TosaTensorValuesGen.tvgDefault,
3463 TosaArgGen.agAxis,
3464 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003465 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003466 "error_if_validators": (
3467 TosaErrorValidator.evAxisSmallerZero,
3468 TosaErrorValidator.evAxisLargerRank,
3469 TosaErrorValidator.evWrongInputType,
3470 TosaErrorValidator.evWrongOutputType,
3471 TosaErrorValidator.evWrongInputList,
3472 TosaErrorValidator.evWrongOutputList,
3473 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003474 },
3475 "slice": {
3476 "op": Op.SLICE,
3477 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003478 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003479 "build_fcn": (
3480 build_slice,
3481 TosaTensorGen.tgBasic,
3482 TosaTensorValuesGen.tvgDefault,
3483 TosaArgGen.agSlice,
3484 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003485 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003486 "error_if_validators": (
3487 TosaErrorValidator.evStartSmallerZero,
3488 TosaErrorValidator.evSizeSmallerEqualZero,
3489 TosaErrorValidator.evStartSizeOutsideBounds,
3490 TosaErrorValidator.evSizeOutputShapeMismatch,
3491 TosaErrorValidator.evInputSizeStartLengthMismatch,
3492 TosaErrorValidator.evWrongRank,
3493 TosaErrorValidator.evWrongInputType,
3494 TosaErrorValidator.evWrongOutputType,
3495 TosaErrorValidator.evWrongInputList,
3496 TosaErrorValidator.evWrongOutputList,
3497 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003498 },
3499 "tile": {
3500 "op": Op.TILE,
3501 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003502 "build_fcn": (
3503 build_tile,
3504 TosaTensorGen.tgBasic,
3505 TosaTensorValuesGen.tvgDefault,
3506 TosaArgGen.agTile,
3507 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003508 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 "error_if_validators": (
3510 TosaErrorValidator.evWrongInputType,
3511 TosaErrorValidator.evWrongOutputType,
3512 TosaErrorValidator.evWrongInputList,
3513 TosaErrorValidator.evWrongOutputList,
3514 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003515 },
3516 "transpose": {
3517 "op": Op.TRANSPOSE,
3518 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003519 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003520 "build_fcn": (
3521 build_transpose,
3522 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003524 TosaArgGen.agTranspose,
3525 ),
3526 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003527 "error_if_validators": (
3528 TosaErrorValidator.evIndexOutsideBounds,
3529 TosaErrorValidator.evIndexUsedTwice,
3530 TosaErrorValidator.evWrongInputType,
3531 TosaErrorValidator.evWrongOutputType,
3532 TosaErrorValidator.evWrongInputList,
3533 TosaErrorValidator.evWrongOutputList,
3534 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003535 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 # Data nodes
3537 "const": {
3538 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003539 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003540 "build_fcn": (
3541 build_const,
3542 TosaTensorGen.tgBasic,
3543 TosaTensorValuesGen.tvgDefault,
3544 None,
3545 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003546 "types": TYPE_FIB,
3547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 "identity": {
3549 "op": Op.IDENTITY,
3550 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551 "build_fcn": (
3552 build_unary,
3553 TosaTensorGen.tgBasic,
3554 TosaTensorValuesGen.tvgDefault,
3555 None,
3556 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003557 "types": TYPE_FIB,
3558 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003559 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003560 "gather": {
3561 "op": Op.GATHER,
3562 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3563 "operands": (1, 0),
3564 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003565 "build_fcn": (
3566 build_gather,
3567 TosaTensorGen.tgBasic,
3568 TosaTensorValuesGen.tvgDefault,
3569 None,
3570 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003571 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003572 "error_if_validators": (
3573 TosaErrorValidator.evWrongInputType,
3574 TosaErrorValidator.evWrongOutputType,
3575 TosaErrorValidator.evWrongInputList,
3576 TosaErrorValidator.evWrongOutputList,
3577 TosaErrorValidator.evWrongRank,
3578 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003579 },
3580 "scatter": {
3581 "op": Op.SCATTER,
3582 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003583 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003584 "operands": (2, 0),
3585 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003586 "build_fcn": (
3587 build_scatter,
3588 TosaTensorGen.tgScatter,
3589 TosaTensorValuesGen.tvgDefault,
3590 None,
3591 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003592 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003593 "error_if_validators": (
3594 TosaErrorValidator.evWrongInputType,
3595 TosaErrorValidator.evWrongOutputType,
3596 TosaErrorValidator.evWrongInputList,
3597 TosaErrorValidator.evWrongOutputList,
3598 TosaErrorValidator.evWrongRank,
3599 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003600 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003601 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003602 "resize": {
3603 "op": Op.RESIZE,
3604 "operands": (1, 0),
3605 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003606 "build_fcn": (
3607 build_resize,
3608 TosaTensorGen.tgNHWC,
3609 TosaTensorValuesGen.tvgDefault,
3610 TosaArgGen.agResize,
3611 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003612 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003613 "invalid_test_validators": (
3614 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
3615 TosaInvalidValidator.ivBadStride,
3616 ),
3617 "error_if_validators": (
3618 TosaErrorValidator.evMaxDimExceeded,
3619 TosaErrorValidator.evStrideSmallerEqualZero,
3620 TosaErrorValidator.evStrideLargerDimension,
3621 TosaErrorValidator.evStrideLargerEqualMax,
3622 TosaErrorValidator.evOffsetSmallerEqualMin,
3623 TosaErrorValidator.evOffsetLargerEqualMax,
3624 TosaErrorValidator.evShiftNotZero,
3625 TosaErrorValidator.evShiftSmallerOne,
3626 TosaErrorValidator.evShiftLargerEleven,
3627 TosaErrorValidator.evWrongInputType,
3628 TosaErrorValidator.evWrongOutputType,
3629 TosaErrorValidator.evWrongRank,
3630 TosaErrorValidator.evWrongInputList,
3631 TosaErrorValidator.evWrongOutputList,
3632 TosaErrorValidator.evBatchMismatch,
3633 TosaErrorValidator.evChannelMismatch,
3634 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003635 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003636 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003637 "cast": {
3638 "op": Op.CAST,
3639 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003640 "build_fcn": (
3641 build_cast,
3642 TosaTensorGen.tgBasic,
3643 TosaTensorValuesGen.tvgDefault,
3644 TosaArgGen.agCast,
3645 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003646 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003647 "error_if_validators": (
3648 TosaErrorValidator.evWrongInputType,
3649 TosaErrorValidator.evWrongOutputType,
3650 TosaErrorValidator.evWrongInputList,
3651 TosaErrorValidator.evWrongOutputList,
3652 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003653 },
3654 "rescale": {
3655 "op": Op.RESCALE,
3656 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003658 "build_fcn": (
3659 build_rescale,
3660 TosaTensorGen.tgBasic,
3661 TosaTensorValuesGen.tvgDefault,
3662 TosaArgGen.agRescale,
3663 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003664 "types": [
3665 DType.UINT8,
3666 DType.INT8,
3667 DType.INT16,
3668 DType.INT32,
3669 DType.INT48,
3670 DType.UINT16,
3671 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003672 "error_if_validators": (
3673 TosaErrorValidator.evInputZeroPointNotZero,
3674 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003675 TosaErrorValidator.evU16InputZeroPointNotValid,
3676 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003677 TosaErrorValidator.evScaleTrue,
3678 TosaErrorValidator.evScaleNotTrue,
3679 TosaErrorValidator.evWrongInputType,
3680 TosaErrorValidator.evWrongOutputType,
3681 TosaErrorValidator.evWrongRank,
3682 TosaErrorValidator.evWrongInputList,
3683 TosaErrorValidator.evWrongOutputList,
3684 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003685 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003686 # Custom
3687 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003688 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003689 # Two variants of cond_if, one that generates one of two constant tensors (no
3690 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3691 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003692 "cond_if_const": {
3693 "op": Op.COND_IF,
3694 "operands": (0, 2),
3695 "build_fcn": (
3696 build_cond_if_const,
3697 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003698 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003699 TosaArgGen.agCondIf,
3700 ),
3701 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003702 "error_if_validators": (
3703 TosaErrorValidator.evOutputListThenGraphMismatch,
3704 TosaErrorValidator.evOutputListElseGraphMismatch,
3705 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003706 },
3707 "cond_if_binary": {
3708 "op": Op.COND_IF,
3709 "operands": (2, 0),
3710 "build_fcn": (
3711 build_cond_if_binary,
3712 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003713 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003714 TosaArgGen.agCondIf,
3715 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003716 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003717 "error_if_validators": (
3718 TosaErrorValidator.evInputListThenGraphMismatch,
3719 TosaErrorValidator.evInputListElseGraphMismatch,
3720 TosaErrorValidator.evOutputListThenGraphMismatch,
3721 TosaErrorValidator.evOutputListElseGraphMismatch,
3722 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003723 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003724 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003725 "while_loop": {
3726 "op": Op.WHILE_LOOP,
3727 "operands": (0, 1),
3728 "build_fcn": (
3729 build_while_loop,
3730 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003731 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003732 TosaArgGen.agWhileLoop,
3733 ),
3734 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003735 "error_if_validators": (
3736 TosaErrorValidator.evInputListOutputListMismatch,
3737 TosaErrorValidator.evInputListCondGraphMismatch,
3738 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3739 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3740 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3741 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003742 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003743 }
3744
Kevin Cheng550ccc52021-03-03 11:21:43 -08003745
Eric Kunzee5e26762020-10-13 16:11:07 -07003746class OutputShaper:
3747 # Methods in this class compute the expected output shape and datatype
3748 # for common classes of operations
def __init__(self):
    """Create a shaper; OutputShaper keeps no per-instance state."""
3751
3752 # These methods return arguments that can be used for
3753 # creating a new output tensor
3754 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a broadcasting element-wise binary op.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only consulted for ``ErrorIf.WrongOutputType``.
        a, b: input tensors; they must share a dtype and, unless
            ``ErrorIf.RankMismatch`` is being generated, a rank.
        error_name: the ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.

    A size-1 dimension of ``a`` takes ``b``'s size only when no ERROR_IF
    case is requested; otherwise ``a``'s dimension is copied as-is.
    """
    # Rank-mismatch tests deliberately pair tensors of different ranks.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    # Any dtype other than the input's provokes the wrong-output-type error.
    others = set(
        [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
    ) - set([a.dtype])
    return ser.addOutput(shape, rng.choice(list(others)))
Eric Kunzee5e26762020-10-13 16:11:07 -07003781
3782 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor of a binary op whose inputs must match exactly.

    ``a`` and ``b`` must have the same rank, the same dtype and identical
    dimensions (checked with ``assert``); the output copies that shape and
    ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
        shape.append(dim_a)

    return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003793
3794 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an element-wise unary op.

    The output reuses ``a.shape`` and normally ``a.dtype``. For
    ``ErrorIf.WrongOutputType`` a random dtype different from ``a.dtype``
    is drawn from ``rng`` instead.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([a.dtype])
        outputDType = rng.choice(list(others))

    return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003810
3811 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of SELECT(cond, a, b).

    For each dimension: when ``cond`` has size 1 and no ERROR_IF case is
    requested, the output takes the largest size among ``cond``, ``a`` and
    ``b``; in every other case it takes ``cond``'s size. The dtype follows
    ``a`` unless ``ErrorIf.WrongOutputType`` asks for a random different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for axis, cond_dim in enumerate(cond.shape):
        if cond_dim != 1 or error_name is not None:
            shape.append(cond_dim)
        else:
            shape.append(max(cond_dim, a.shape[axis], b.shape[axis]))

    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([a.dtype])
        outputDType = rng.choice(list(others))
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003838
3839 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor of an element-wise comparison.

    A size-1 dimension of ``a`` takes ``b``'s size whenever ``b`` has that
    dimension. The dtype is BOOL, except for ``ErrorIf.WrongOutputType``
    where a random non-BOOL dtype is chosen.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast; the rank guard keeps rank-mismatch tests from indexing
    # past the end of b.
    b_rank = len(b.shape)
    shape = [
        b.shape[i] if (a_dim == 1 and i < b_rank) else a_dim
        for i, a_dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # Every candidate here differs from the correct BOOL result type.
        outputDType = rng.choice(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        )
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003866
3867 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction over ``axis``.

    Normally the reduced dimension becomes 1. For
    ``ErrorIf.ShapeOfAxisNotOne`` it keeps ``a``'s size, replaced by a random
    size from 2 to 9 if that size happened to be 1. For the
    ``AxisSmallerZero`` / ``AxisLargerRank`` cases the input shape is copied
    untouched (presumably because ``axis`` is not a valid index then).
    ``WrongOutputType`` swaps the dtype for a random different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if shape[axis] == 1:
            shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([a.dtype])
        outputDType = rng.choice(list(others))
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003893
3894 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor of ARGMAX over ``axis``.

    The reduced axis is removed from ``a``'s shape, except for the
    invalid-axis ERROR_IF cases. ``ArgmaxOutputRankMismatch`` then randomly
    drops the leading dim (only if more than one remains) or appends a 1.
    ``ArgmaxOutputShapeMismatch`` grows every remaining dim by 1 to 9.
    ``WrongOutputType`` picks a random non-INT32 dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        if rng.choice([True, False]) and len(shape) > 1:
            shape.pop(0)
        else:
            shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        shape = [dim + rng.integers(1, 10) for dim in shape]

    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([DType.INT32])
        outputDType = rng.choice(list(others))
    else:
        outputDType = DType.INT32

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003925
3926 @staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Add the output tensor of CONV2D.

    Layouts: ifm NHWC, filter OHWI, output NHWC with C = O. Each spatial
    size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.

    ``ErrorIf.ConvOutputShapeMismatch`` enlarges H, W or both by a random
    multiple (1-3) of the stride. The output dtype is INT32 for INT8 input,
    INT48 for INT16 and FLOAT for FLOAT. ``WrongInputType`` falls back to
    INT32, and ``WrongOutputType`` picks any other usable dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.

    Raises:
        Exception: unsupported input dtype outside a WrongInputType test.
    """

    def spatial(size, pad_before, pad_after, kernel, dilation, stride):
        extent = size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return extent // stride + 1

    h = spatial(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = spatial(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 1 grows H, 2 grows W, 3 grows both; whole strides avoid also
        # triggering the non-integer output case.
        if change != 2:
            h = h + rng.choice(choices) * strides[0]
        if change != 1:
            w = w + rng.choice(choices) * strides[1]

    accumulators = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulators:
        out_dtype = accumulators[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Some plausible output dtype when the input dtype is the error.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput([ifm.shape[0], h, w, filter.shape[0]], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003977
3978 @staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Add the output tensor of CONV3D.

    Layouts: ifm NDHWC, filter ODHWI, output NDHWC with C = O. Spatial dim
    ``i`` (0 = D, 1 = H, 2 = W) uses ``padding[2*i]``, ``padding[2*i+1]``,
    ``dilations[i]`` and ``strides[i]``::

        (in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1

    ``ErrorIf.ConvOutputShapeMismatch`` enlarges D, H or W, or all three, by
    a random multiple (1-4) of its stride. Output dtype: INT32 for INT8,
    INT48 for INT16, FLOAT for FLOAT. It is INT32 for a ``WrongInputType``
    test, and any other usable dtype for ``WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns.

    Raises:
        Exception: unsupported input dtype outside a WrongInputType test.
    """
    d, h, w = [
        (
            ifm.shape[i + 1]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i + 1] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # 1/2/3 grow only D/H/W respectively, 4 grows all three. Growth is a
        # whole number of strides to avoid the non-integer output case.
        if change in (1, 4):
            d = d + rng.choice(choices) * strides[0]
        if change in (2, 4):
            h = h + rng.choice(choices) * strides[1]
        if change in (3, 4):
            w = w + rng.choice(choices) * strides[2]

    accumulators = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulators:
        out_dtype = accumulators[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Some plausible output dtype when the input dtype is the error.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput([ifm.shape[0], d, h, w, filter.shape[0]], out_dtype)
4039
4040 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, strides, padding, dilations, error_name=None
):
    """Add the output tensor of DEPTHWISE_CONV2D.

    Layouts: ifm NHWC, filter HWCM, output NHW(C*M). Spatial sizes follow
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``
    with the kernel size taken from filter dims 0 (H) and 1 (W).

    ``ErrorIf.ConvOutputShapeMismatch`` enlarges H, W or both by a random
    multiple (1-3) of the stride. Output dtype: INT32 for INT8, INT48 for
    INT16, FLOAT for FLOAT. It is INT32 for a ``WrongInputType`` test, and
    any other usable dtype for ``WrongOutputType``.

    Returns:
        Whatever ``ser.addOutput`` returns.

    Raises:
        Exception: unsupported input dtype outside a WrongInputType test.
    """
    kernel_h = filter.shape[0]
    kernel_w = filter.shape[1]

    h = (
        ifm.shape[1] - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    w = (
        ifm.shape[2] - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 1 grows H, 2 grows W, 3 grows both; whole strides avoid also
        # triggering the non-integer output case.
        if change in (1, 3):
            h = h + rng.choice(choices) * strides[0]
        if change in (2, 3):
            w = w + rng.choice(choices) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]

    accumulators = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accumulators:
        out_dtype = accumulators[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Some plausible output dtype when the input dtype is the error.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput([ifm.shape[0], h, w, out_channels], out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004092
4093 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2D pooling op (NHWC in and out).

    Spatial sizes are ``(in + pad_before + pad_after - kernel) // stride + 1``.
    A non-positive stride or any negative padding makes the test invalid
    anyway, so H and W are simply set to 1. ``PoolingOutputShapeMismatch``
    grows H, W or both by a random multiple (1-3) of the stride, and
    ``WrongOutputType`` swaps the dtype for a random different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
        h = w = 1
    else:
        h, w = [
            (ifm.shape[1 + i] + pad[2 * i] + pad[2 * i + 1] - kernel[i])
            // stride[i]
            + 1
            for i in (0, 1)
        ]

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 1 grows H, 2 grows W, 3 grows both; whole strides avoid also
        # triggering the non-integer output case.
        if change != 2:
            h = h + rng.choice(choices) * stride[0]
        if change != 1:
            w = w + rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    outputDType = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([ifm.dtype])
        outputDType = rng.choice(list(others))

    return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004128
4129 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Add the output tensor of FULLY_CONNECTED.

    Layouts: input [N, IC], filter [OC, IC], output [N, OC].

    The output dtype is the accumulator type for the input dtype: INT32 for
    INT8, INT48 for INT16 and FLOAT for FLOAT. ``ErrorIf.WrongInputType``
    falls back to INT32 for an unsupported input dtype.
    ``ErrorIf.WrongOutputType`` replaces the correct dtype with a random
    different one from INT4/INT8/INT16/INT32/INT48/FLOAT.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used for ``ErrorIf.WrongOutputType``.
        input: input tensor (its ``shape`` and ``dtype`` are read).
        filter: weight tensor (only its ``shape`` is read).
        error_name: the ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.

    Raises:
        Exception: the input dtype is unsupported and no WrongInputType test
            is being generated.
    """
    # NOTE: ``input``/``filter`` shadow builtins but are part of the call
    # interface, so they are kept as-is.
    output_shape = [input.shape[0], filter.shape[0]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if input.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[input.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    if error_name == ErrorIf.WrongOutputType:
        # Derive the wrong choices from the correct dtype instead of a
        # per-input table, so an unexpected input dtype raises the clear
        # exception above rather than hitting an unbound variable. The
        # candidate order is fixed, so a given seed picks the same dtype.
        candidates = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in candidates if t != out_dtype)
        out_dtype = rng.choice(a=incorrect_types)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004176
4177 @staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of MATMUL.

    Layouts: a [N, H, C], b [N, C, W], output [N, H, W].

    The output dtype is the accumulator type for ``a.dtype``: INT32 for
    INT8, INT48 for INT16 and FLOAT for FLOAT. ``ErrorIf.WrongInputType``
    falls back to INT32 for an unsupported input dtype.
    ``ErrorIf.WrongOutputType`` replaces the correct dtype with a random
    different one from INT4/INT8/INT16/INT32/INT48/FLOAT.

    Args:
        ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
        rng: random generator used for ``ErrorIf.WrongOutputType``.
        a: left operand (its ``shape`` and ``dtype`` are read).
        b: right operand (only its ``shape`` is read).
        error_name: the ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.

    Raises:
        Exception: ``a.dtype`` is unsupported and no WrongInputType test is
            being generated.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    accumulator_dtypes = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if a.dtype in accumulator_dtypes:
        out_dtype = accumulator_dtypes[a.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    if error_name == ErrorIf.WrongOutputType:
        # Derive the wrong choices from the correct dtype instead of a
        # per-input table, so an unexpected input dtype raises the clear
        # exception above rather than hitting an unbound variable. The
        # candidate order is fixed, so a given seed picks the same dtype.
        candidates = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        incorrect_types = tuple(t for t in candidates if t != out_dtype)
        out_dtype = rng.choice(a=incorrect_types)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004224
4225 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor of CONCAT along ``axis``.

    ``a`` holds the input tensors. The output starts from the first input's
    shape. Unless a rank-mismatch or invalid-axis ERROR_IF case is being
    generated, the ``axis`` dimension is widened by every other input's size
    along it. ``ConcatShapeSumMismatch`` adds a further random 5 to 9, and
    ``WrongOutputType`` swaps the dtype for a random different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    first = a[0]
    rest = a[1:]

    output_shape = first.shape.copy()
    skip_sum = (
        error_name == ErrorIf.ConcatInputRankMismatch
        or error_name in (ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero)
    )
    if not skip_sum:
        for tensor in rest:
            output_shape[axis] = output_shape[axis] + tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] = output_shape[axis] + rng.integers(5, 10)

    outputDType = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        }
        outputDType = rng.choice(list(all_dtypes - set([first.dtype])))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004258
4259 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of PAD.

    Output dim ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``. For
    ``ErrorIf.PadSmallerZero``, any dim that came out below 1 is clamped to
    1 so that an output tensor can still be created. ``WrongOutputType``
    swaps the dtype for a random different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    output_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    # Negative padding in the PadSmallerZero test can shrink dims below 1.
    if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
        output_shape = [max(dim, 1) for dim in output_shape]

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([a.dtype])
        outputDType = rng.choice(list(others))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004285
4286 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of RESHAPE with the requested ``shape``.

    For ``ErrorIf.TensorSizeInputOutputMismatch`` every dim is grown by a
    random 1 to 9 so the element count no longer matches the input.
    ``WrongOutputType`` swaps the dtype for a random different one.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    output_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([a.dtype])
        outputDType = rng.choice(list(others))

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004308
4309 @staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add the output tensor of SLICE, shaped by the requested ``size``.

    ``WrongOutputType`` swaps the dtype for a random different one.
    ``SizeOutputShapeMismatch`` nudges every dim away from ``size``: dims of
    2 or less are only grown (by 1 or 2), larger dims move by -2, -1, +1 or
    +2. ``start`` is not used here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    outputDType = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        others = set(
            [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
        ) - set([a.dtype])
        outputDType = rng.choice(list(others))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(output_shape):
            deltas = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = extent + rng.choice(deltas)

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004338
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor of a TILE test with the serializer.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: numpy random Generator used for ERROR_IF perturbations.
        a: input tensor; its ``shape`` and ``dtype`` are read.
        multiples: per-dimension repeat counts, same length as ``a.shape``.
        error_name: optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    new_shape = a.shape.copy()
    assert len(multiples) == len(new_shape)

    for axis, dim in enumerate(a.shape):
        new_shape[axis] = dim * multiples[axis]

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        # Set difference fixes the candidate order seen by rng.choice.
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004362
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor of a TRANSPOSE test with the serializer.

    Normally output dimension ``i`` is ``a.shape[perms[i]]``.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: numpy random Generator used for ERROR_IF perturbations.
        a: input tensor; its ``shape`` and ``dtype`` are read.
        perms: permutation, one entry per input dimension.
        error_name: optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    new_shape = a.shape.copy()
    rank = len(new_shape)

    assert len(perms) == rank

    if error_name == ErrorIf.IndexOutsideBounds:
        # perms is not consulted here; every output dim copies input dim 0.
        # Presumably perms holds out-of-range entries in this case - confirm.
        for axis in range(rank):
            new_shape[axis] = a.shape[0]
    else:
        for axis in range(rank):
            new_shape[axis] = a.shape[perms[axis]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        # Set difference fixes the candidate order seen by rng.choice.
        out_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))

    return ser.addOutput(new_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004390
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor of a GATHER test with the serializer.

    The output shape is ``[values.shape[0], indices.shape[1],
    values.shape[2]]`` with the dtype of ``values``.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: numpy random Generator used for ERROR_IF perturbations.
        values: rank-3 values tensor (rank unchecked for WrongRank tests).
        indices: rank-2 indices tensor sharing dim 0 with ``values``.
        error_name: optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # WrongRank tests intentionally violate these rank requirements.
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    batch, width, chans = values.shape[0], indices.shape[1], values.shape[2]
    output_shape = [batch, width, chans]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        # Set difference fixes the candidate order seen by rng.choice.
        out_dtype = rng.choice(list(set(all_dtypes) - {values.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004414
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor of a SCATTER test with the serializer.

    The output has the shape of ``values_in`` and, unless a wrong output
    type is being tested, its dtype.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: numpy random Generator used for ERROR_IF perturbations.
        values_in: rank-3 tensor [N, K, C] (rank unchecked for WrongRank).
        indices: rank-2 tensor [N, W].
        input: rank-3 tensor [N, W, C] of values to scatter.
        error_name: optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    # WrongRank tests intentionally violate these rank requirements.
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    # Copy instead of aliasing: handing values_in.shape itself to the output
    # tensor would make the two tensors share one mutable shape list. Every
    # other shaper copies its source shape the same way.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        # Set difference fixes the candidate order seen by rng.choice.
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004441
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor of a TABLE test with the serializer.

    The output keeps the input shape. Its dtype is INT32 for an INT16
    input and INT8 otherwise.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: numpy random Generator used for ERROR_IF perturbations.
        input: input tensor, INT8 or INT16 unless testing WrongInputType.
        error_name: optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        out_dtype = DType.INT32
    else:
        out_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        # Any listed dtype except the correct one, in list order.
        candidates = [
            dt
            for dt in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            )
            if dt != out_dtype
        ]
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004459
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor of a RESIZE test with the serializer.

    Normally the output shape is ``[input.shape[0], output_dims[0],
    output_dims[1], input.shape[3]]`` and its dtype is ``output_dtype``.
    Only ``input.shape``, ``output_dims``, ``output_dtype`` and
    ``error_name`` are read here; the remaining parameters (mode, stride,
    offset, shift, stride_fp, offset_fp, input_dtype) are accepted for
    interface parity.

    Returns:
        The value returned by ``serializer.addOutput``.
    """
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): repeats output_dims[0] for both spatial dims and uses
        # input dim 0 as the channel count. The input rank is deliberately
        # wrong here, so input.shape[3] may not exist; confirm whether
        # output_dims[1] was intended for the width.
        out_shape = [
            input.shape[0],
            output_dims[0],
            output_dims[0],
            input.shape[0],
        ]
    else:
        batch = input.shape[0]
        if error_name == ErrorIf.BatchMismatch:
            # Output batch differs from the input batch by 1..9.
            batch += rng.integers(1, 10)

        height, width = output_dims[0], output_dims[1]

        channels = input.shape[3]
        if error_name == ErrorIf.ChannelMismatch:
            # Output channels differ from the input channels by 1..9.
            channels += rng.integers(1, 10)

        out_shape = [batch, height, width, channels]

    return serializer.addOutput(out_shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004507
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register the output tensor of a type-conversion test.

    The output keeps ``val``'s shape and uses ``out_dtype`` unchanged;
    ``rng`` and ``error_name`` are accepted for interface parity only.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004511
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Register the output tensor of a TRANSPOSE_CONV2D test.

    Output dtype by input dtype: INT8 -> INT32, INT16 -> INT48,
    FLOAT -> FLOAT. When testing WrongInputType with another input dtype,
    INT32 is used.

    NOTE: for ConvOutputShapeMismatch the caller's ``output_shape`` is
    modified in place (dims 1 and/or 2 grow by 1..3), so any later use of
    that same list by the caller sees the perturbed values.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: numpy random Generator used for ERROR_IF perturbations.
        ifm: input feature-map tensor; only ``ifm.dtype`` is read.
        output_shape: requested output shape (indices 1 and 2 may change).
        error_name: optional ErrorIf case being generated.

    Returns:
        The value returned by ``ser.addOutput``.

    Raises:
        Exception: unsupported input dtype outside a WrongInputType test.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> perturb dim 1, 2 -> perturb dim 2, 3 -> perturb both.
        which = rng.choice(steps)
        if which in [1, 3]:
            output_shape[1] += rng.choice(steps)
        if which in [2, 3]:
            output_shape[2] += rng.choice(steps)

    out_dtype_for = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    out_dtype = out_dtype_for.get(ifm.dtype)
    if out_dtype is None:
        if error_name != ErrorIf.WrongInputType:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
        # Pick a plausible output dtype when the input dtype is invalid.
        out_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(output_shape, out_dtype)