blob: c9c6d7ed2f8c3e4e2f01fc631dca0860de163356 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
16from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000017from tosa.DType import DType
18from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010019
20
Eric Kunzee5e26762020-10-13 16:11:07 -070021class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010022 # Maximum rank of tensor supported by test generator.
23 TOSA_TENSOR_MAX_RANK = 6
24
def __init__(self, args):
    """Create a test generator configured from parsed command-line arguments.

    Reads args.output_dir (base directory for generated tests) and
    args.random_seed (seed for the numpy Generator kept in self.rng).
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer exists until createSerializer() is called for a test.
    self.ser = None
    self.rng = np.random.default_rng(seed=self.random_seed)
    # Op-list set-up keeps its original position (after args/rng are set)
    # in case these helpers read them.
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set via setTargetShape(), makeShape() returns this shape instead
    # of a random one.
    self.targetted_shape = None
36
def createSerializer(self, opName, testPath):
    """Create a TosaSerializer that writes into <basePath>/<opName>/<testPath>.

    The relative path opName/testPath is remembered in self.testPath so that
    serialize() later writes its files into the same directory. The directory
    is created if it does not already exist.
    """
    self.testPath = os.path.join(opName, testPath)
    outputDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outputDir, exist_ok=True)
    self.ser = ts.TosaSerializer(outputDir)
43
def getSerializer(self):
    """Return the serializer from the latest createSerializer() call (None before)."""
    serializer = self.ser
    return serializer
46
def serialize(self, testName):
    """Write the current test to <basePath>/<testPath>.

    Two files are written:
      * "<testName>.tosa" - the bytes returned by self.ser.serialize()
      * "desc.json"       - the text returned by self.ser.writeJson(), which
                            is given the .tosa file name.
    """
    outputDir = os.path.join(self.basePath, self.testPath)
    tosaFileName = "{}.tosa".format(testName)

    with open(os.path.join(outputDir, tosaFileName), "wb") as fd:
        fd.write(self.ser.serialize())

    # Explicit encoding: without it the JSON text would be encoded with the
    # platform's locale-preferred encoding, making output machine-dependent.
    with open(os.path.join(outputDir, "desc.json"), "w", encoding="utf-8") as fd:
        fd.write(self.ser.writeJson(tosaFileName))
Eric Kunzee5e26762020-10-13 16:11:07 -070055
def resetRNG(self, seed=None):
    """Replace self.rng with a freshly seeded numpy Generator.

    seed: explicit seed to use. When None, self.random_seed + 1 is used.
    """
    effectiveSeed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(effectiveSeed)
60
def getRandTensor(self, shape, dtype):
    """Return random tensor data of the given shape for a TOSA dtype.

    BOOL gives np.bool_ values. FLOAT gives np.float32 values drawn from [0, 1).
    Integer dtypes give uniform integers over the dtype's range, stored as
    np.int32, or as np.int64 for INT48. INT4 uses the TOSA weight range
    -7..7. Raises Exception for any other dtype.
    """
    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.FLOAT:
        return np.float32(self.rng.random(size=shape))

    # Integer dtypes: half-open [low, high) range plus the numpy container.
    if dtype == DType.INT4:
        # TOSA specific INT4 weight range from -7 to 7
        low, high, container = -7, 8, np.int32
    elif dtype == DType.INT8:
        low, high, container = -128, 128, np.int32
    elif dtype == DType.UINT8:
        low, high, container = 0, 256, np.int32
    elif dtype == DType.INT16:
        low, high, container = -32768, 32768, np.int32
    elif dtype == DType.UINT16:
        low, high, container = 0, 65536, np.int32
    elif dtype == DType.INT32:
        low, high, container = -(1 << 31), 1 << 31, np.int32
    elif dtype == DType.INT48:
        low, high, container = -(1 << 47), 1 << 47, np.int64
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))

    return container(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070087
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-data placeholder per (shape, dtype) pair.

    Returns the values returned by self.ser.addPlaceholder, in the order of
    shape_list. shape_list and dtype_list must be the same length.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addPlaceholder(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
98
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-data constant per (shape, dtype) pair.

    Returns the values returned by self.ser.addConst, in the order of
    shape_list. shape_list and dtype_list must be the same length.
    """
    assert len(shape_list) == len(dtype_list)

    return [
        self.ser.addConst(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
109
def makeShape(self, rank):
    """Return a tensor shape as an np.int32 array.

    If a non-empty target shape was set via setTargetShape(), it is returned
    and `rank` is ignored. Otherwise `rank` dimensions are drawn uniformly
    from [args.tensor_shape_range[0], args.tensor_shape_range[1]).
    """
    target = self.targetted_shape
    # `if target:` raises "truth value is ambiguous" for a multi-element
    # numpy array (and tests element values for a 1-element one), so arrays
    # are checked by size. Other values keep plain truthiness, so None and []
    # still fall through to a random shape.
    if isinstance(target, np.ndarray):
        useTarget = target.size > 0
    else:
        useTarget = bool(target)
    if useTarget:
        return np.int32(target)
    return np.int32(
        self.rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700120
def setTargetShape(self, shape):
    """Make makeShape() return `shape`; a falsy value such as None restores random shapes."""
    self.targetted_shape = shape
123
def randInt(self, low=0, high=256):
    """Return one random integer from [low, high) as an np.int32 scalar."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
126
def getRandNumberDType(self, dtype):
    """Return a single random value for the given TOSA dtype.

    FLOAT returns a float from self.rng.random(), in [0, 1). BOOL returns a
    random choice of False/True. INT48 returns an np.int64 scalar. Other
    integer dtypes, including the unsigned UINT8/UINT16 that getRandTensor
    also supports, return an np.int32 scalar drawn uniformly from the dtype's
    range. Raises Exception for any other dtype.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: INT48 values do not fit in an int32 container.
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
149
def shapeStr(self, shape):
    """Format a shape as its dimensions joined by "x", e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
def typeStr(self, t):
    """Return the short dtype name used when building test names.

    A list of dtypes (at least two entries) is rendered from its first two
    entries as "<first>x<second>". Raises Exception for an unknown dtype.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        first, second = t[0], t[1]
        return "{}x{}".format(self.typeStr(first), self.typeStr(second))

    # Checked in order with ==, like the equivalent if/elif chain.
    shortNames = (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.UINT16, "u16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    )
    for candidate, name in shortNames:
        if t == candidate:
            return name
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184
def typeWidth(self, t):
    """Get the datatype width in bits.

    Covers the integer types plus FLOAT (32) and BOOL (1). Raises Exception
    for any other dtype.
    """
    # Checked in order with ==, like the equivalent if/elif chain.
    bitWidths = (
        (DType.INT4, 4),
        (DType.INT8, 8),
        (DType.UINT8, 8),
        (DType.INT16, 16),
        (DType.UINT16, 16),
        (DType.INT32, 32),
        (DType.INT48, 48),
        (DType.FLOAT, 32),
        (DType.BOOL, 1),
    )
    for candidate, width in bitWidths:
        if t == candidate:
            return width
    raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
# Operator build functions
# The argument generators return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
# used to generate the test name and build_fcn_arg_list is expanded and
# passed to the operator's test build function (the build_* methods below).
213
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a unary operator on tensor `a` and return its output tensor.

    An op given as a bare int (the placeholder case, per the original note)
    and IDENTITY are serialized directly, without ERROR_IF validation. For
    other ops, returns None if TosaErrorValidator.evValidateErrorIfs reports
    failure.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    if isinstance(op, int) or op["op"] == Op.IDENTITY:
        opCode = op if isinstance(op, int) else op["op"]
        self.ser.addOperator(opCode, a.name, out.name, None, qinfo)
        return out

    # WrongOutputType tests may yield a non-8-bit output; rebuild the quant
    # info from the actual input/output dtypes.
    if error_name == ErrorIf.WrongOutputType and out.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, a.dtype),
            TosaQuantGen.getQinfo(self, out.dtype),
        )

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        qinfo=qinfo,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
    return out
260
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a broadcasting binary operator on (a, b) and return its output tensor.

    Returns None if TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return out
293
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a non-broadcasting binary operator on (a, b) and return its output tensor.

    NOTE(review): validator_fcns and error_name are accepted (presumably to
    match the other build functions' signatures) but are not used, so no
    ERROR_IF validation or list invalidation happens here.
    """
    out = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operandNames = [a.name, b.name]
    self.ser.addOperator(op["op"], operandNames, [out.name])
    return out
298
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator on (a, b) and return its output.

    `round` is stored via ArithmeticRightShiftAttribute. Returns None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    shiftAttr = ts.TosaSerializerAttribute()
    shiftAttr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_list, output_list, shiftAttr)
    return out
336
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply operator on (a, b) with the given shift attribute.

    Non-FLOAT inputs always produce an INT32 result, except in a
    WrongOutputType test, where the output dtype is drawn at random from
    INT8/INT16/INT48. Returns None if TosaErrorValidator.evValidateErrorIfs
    reports failure.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply: force the result to INT32 for INT types...
    if a.dtype != DType.FLOAT:
        out.setDtype(DType.INT32)
    # ...unless the test deliberately wants a wrong output type.
    if error_name == ErrorIf.WrongOutputType:
        wrongDtypes = [DType.INT8, DType.INT16, DType.INT48]
        out.setDtype(self.rng.choice(wrongDtypes))

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    mulAttr = ts.TosaSerializerAttribute()
    mulAttr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_list, output_list, mulAttr)
    return out
381
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table operator on tensor `a` using `table` as its attribute data.

    Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    tableAttr = ts.TosaSerializerAttribute()
    tableAttr.TableAttribute(table)

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, tableAttr)

    return out
415
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select operator with inputs (cond, a, b).

    Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return out
452
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a binary comparison operator on (a, b).

    Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return out
491
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator over `axis` of tensor `a`.

    Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    axisAttr = ts.TosaSerializerAttribute()
    axisAttr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, axisAttr)
    return out
526
def build_pool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator with the given kernel, stride and pad.

    Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    out = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # For WrongInputType tests with a non-8-bit input, rebuild the quant info
    # from the actual input/output dtypes.
    if error_name == ErrorIf.WrongInputType and input.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(
            TosaQuantGen.getQinfo(self, input.dtype),
            TosaQuantGen.getQinfo(self, out.dtype),
        )

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=out,
        input_list=input_list,
        output_list=output_list,
        num_operands=sum(op["operands"]),
    )
    if not passed:
        return None

    poolAttr = ts.TosaSerializerAttribute()
    poolAttr.PoolAttribute(kernel, stride, pad)

    self.ser.addOperator(op["op"], input_list, output_list, poolAttr, qinfo)
    return out
585
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution of `ifm` with `filter` and `bias`.

    `padding` must have 4 entries. Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    assert len(padding) == 4
    out = OutputShaper.conv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # For WrongInputType tests with a non-8-bit input, rebuild the conv quant
    # info from the actual input/output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out.dtype),
        )

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out.shape,
    )
    if not passed:
        return None

    convAttr = ts.TosaSerializerAttribute()
    convAttr.ConvAttribute(padding, strides, dilations)

    self.ser.addOperator(op["op"], input_list, output_list, convAttr, qinfo)
    return out
649
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution of `ifm` with `filter` and `bias`.

    `padding` must have 6 entries. Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.
    """
    assert len(padding) == 6
    out = OutputShaper.conv3dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    # For WrongInputType tests with a non-8-bit input, rebuild the conv quant
    # info from the actual input/output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out.dtype),
        )

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out.shape,
    )
    if not passed:
        return None

    convAttr = ts.TosaSerializerAttribute()
    convAttr.ConvAttribute(padding, strides, dilations)

    self.ser.addOperator(op["op"], input_list, output_list, convAttr, qinfo)
    return out
713
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    stride,
    outpad,
    dilation,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D transpose convolution of `ifm` with `filter` and `bias`.

    `outpad` must have 4 entries. The output tensor is shaped from
    `output_shape`. Returns the output tensor, or None if
    TosaErrorValidator.evValidateErrorIfs reports failure.

    NOTE(review): `dilation` is written to the TransposeConvAttribute but is
    not passed to the ERROR_IF validators (conv2d/conv3d do pass it).
    """
    assert len(outpad) == 4
    out = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, error_name
    )

    # For WrongInputType tests with a non-8-bit input, rebuild the conv quant
    # info from the actual input/output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(
            TosaQuantGen.getQinfo(self, ifm.dtype),
            TosaQuantGen.getQinfo(self, out.dtype),
        )

    # Invalidate Input/Output list for error_if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=sum(op["operands"]),
        output_list=output_list,
        pad=outpad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out.shape,
    )
    if not passed:
        return None

    tconvAttr = ts.TosaSerializerAttribute()
    tconvAttr.TransposeConvAttribute(outpad, stride, dilation, output_shape)

    self.ser.addOperator(op["op"], input_list, output_list, tconvAttr, qinfo)
    return out
777
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator and return its output tensor.

    For ``ErrorIf.WrongInputType`` tests whose input is not INT8/UINT8,
    ``qinfo`` is regenerated from the real input/output dtypes.
    Returns None, adding nothing to the serializer, when the ERROR_IF
    validators reject the configuration.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
    )

    needs_new_qinfo = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if needs_new_qinfo:
        # Keep the quantization info consistent with the substituted dtype.
        in_zp = TosaQuantGen.getQinfo(self, ifm.dtype)
        out_zp = TosaQuantGen.getQinfo(self, out_tensor.dtype)
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(in_zp, out_zp)

    operand_count = sum(op["operands"])
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not accepted:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations)

    self.ser.addOperator(op["op"], inputs, outputs, attr, qinfo)
    return out_tensor
840
def build_fully_connected(
    self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a FULLY_CONNECTED operator over ``ifm``/``filter``/``bias``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, error_name
    )

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, None, qinfo)
    return out_tensor
877
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
    """Serialize a MATMUL of tensors ``a`` and ``b``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, None, qinfo)
    return out_tensor
911
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator (``op["op"]``) of ``a`` along ``axis``.

    ``validator_fcns`` now defaults to None, matching every other builder in
    this class; existing callers that pass it explicitly are unaffected.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
946
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a CLAMP of ``a`` with randomly chosen bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller becomes
    the minimum and the larger the maximum. For ``ErrorIf.MaxSmallerMin``
    the draw is repeated until the values differ, and the roles are swapped
    so that max < min.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    pair = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(pair), max(pair)
    else:
        # Equal bounds could not trigger the error, so redraw until distinct.
        while pair[0] == pair[1]:
            pair = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
        min_val, max_val = max(pair), min(pair)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not accepted:
        return None

    # FLOAT inputs put the bounds in the last two ClampAttribute arguments;
    # every other dtype puts them in the first two (the unused pair is 0, 0).
    if a.dtype == DType.FLOAT:
        clamp_args = (0, 0, min_val, max_val)
    else:
        clamp_args = (min_val, max_val, 0, 0)
    attr = ts.TosaSerializerAttribute()
    attr.ClampAttribute(*clamp_args)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
997
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU of ``a`` with a random FLOAT attribute value.

    ``validator_fcns`` is accepted for interface parity but unused: this
    builder performs no ERROR_IF validation and always adds the operator.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FLOAT)
    attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
1006
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU with ``a`` as its only operand.

    As the comment above notes, this builder still needs an additional
    type/input; only ``a`` is wired up today. ``validator_fcns`` is unused
    and no ERROR_IF validation is performed.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1013
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a SIGMOID of ``a``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1044
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a TANH of ``a``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1075
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a CONCAT over a variable number of input tensors.

    The positional arguments are the input tensors followed by the concat
    axis as the final element. The axis travels with the variable-length
    tensor list. Outside ``ErrorIf.WrongInputType`` tests the axis must be
    an ``int``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    if error_name != ErrorIf.WrongInputType:
        # isinstance rather than an exact type() comparison (PEP 8 idiom).
        assert isinstance(
            a[-1], int
        ), f"concat axis must be an int, got {type(a[-1]).__name__}"

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001124
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD of ``a``.

    ``padding`` is flattened (via ``.flatten()``) into the PadAttribute along
    with the integer and float pad constants.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    attr.PadAttribute(flat_padding, pad_const_int, pad_const_float)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr, qinfo)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001170
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of ``a`` to ``newShape``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1206
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REVERSE of ``a`` along ``axis``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1241
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of ``a`` using permutation ``perms``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1276
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of ``a`` beginning at ``start`` with extent ``size``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1314
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE of ``a`` replicated by ``multiples``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1348
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER from ``values`` using a freshly generated index tensor.

    A constant INT32 indices tensor of shape (values.shape[0], W) is added to
    the serializer. W is drawn from ``args.tensor_shape_range``, and every
    index lies in ``[0, values.shape[1])``, so no lookup goes out of range.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    k_extent = values.shape[1]
    w_extent = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = self.rng.integers(
        low=0, high=k_extent, size=[values.shape[0], w_extent]
    ).astype(np.int32)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholders, consts = op["operands"]
    # ERROR_IF tests may deliberately invalidate the operand/result lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=placeholders + consts,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1395
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER of ``input`` rows into ``values_in``.

    A constant INT32 indices tensor of shape (N, W) is generated here, where
    N = values_in.shape[0], K = values_in.shape[1] and W = input.shape[1].
    Every index lies in ``[0, K)``.

    The TOSA specification does not permit the same output index to be
    written more than once within a single SCATTER. Whenever W <= K, each
    batch row's indices are therefore drawn without replacement (a random
    permutation of [0, K) truncated to W). When W > K, a duplicate-free row
    is impossible, so indices are drawn with replacement as before.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the configuration (no operator is added in that case).
    """
    N = values_in.shape[0]
    K = values_in.shape[1]
    W = input.shape[1]

    if W <= K:
        # Unique indices per batch row so no output position is written twice.
        rows = [self.rng.permutation(K)[:W] for _ in range(N)]
        # reshape keeps the (N, W) shape even when N or W is zero.
        indices_arr = np.array(rows, dtype=np.int32).reshape(N, W)
    else:
        # NOTE(review): W > K forces repeated indices; such tests still
        # violate the SCATTER uniqueness rule. Consider constraining W <= K
        # in the tensor generator.
        indices_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [values_in.name, indices.name, input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return result_tens
1440
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator test.

    The resize parameters (mode, integer and floating-point stride/offset,
    shift and output dimensions) are forwarded unchanged to
    OutputShaper.resizeOp, to the ERROR_IF validators and to the serialized
    ResizeAttribute.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
        error_name,
    )

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    primary_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        shift=shift,
        stride=stride,
        stride_fp=stride_fp,
        offset=offset,
        offset_fp=offset_fp,
        input_shape=input.shape,
        input_dtype=input_dtype,
        output_shape=output_dims,
        output_dtype=output_dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=primary_count + const_count,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )
    self.ser.addOperator(op["op"], input_list, output_list, resize_attr)
    return result_tens
1512
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator with two inputs and two outputs.

    One result tensor is created per input via OutputShaper.unaryOp. No
    ERROR_IF validation is performed for this operator, and only the first
    result tensor is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator's test-list entry; like every other build_*
    # method, pass the Op stored under "op" rather than the whole entry.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1520
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as a graph output tensor and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted only so the
    signature matches the other build_* methods; they are not used.
    """
    const_tens = val
    self.ser.addOutputTensor(const_tens)
    return const_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001524
1525 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST operator converting ``val`` to ``out_dtype``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
1558
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator test converting ``val`` to ``out_dtype``.

    Random input/output zero points are chosen from the data types (or from
    the ERROR_IF case being generated). A random scale per channel (or a
    single scale) is converted to multiplier/shift pairs. For positive
    scale32 tests, the input placeholder data saved on disk is clamped so it
    satisfies the apply_scale_32 value requirement for the chosen shifts.

    Args:
        op: operator test-list entry (uses the "op" and "operands" keys).
        val: input placeholder tensor.
        out_dtype: output DType.
        scale32: selects scale32 (True) or scale16 (False) multiplier/shift
            computation and scale clipping.
        double_round: forwarded to the validators and the RescaleAttribute.
        per_channel: when True, one scale per element of the last dimension
            of ``val``; otherwise a single scale.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Number of scale channels: last dimension when per-channel, else one.
    if per_channel:
        nc = val.shape[-1]
    else:
        nc = 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Pick the input zero point. Branches that may produce a non-zero zero
    # point widen the type by one bit when sizing the scale (presumably to
    # leave room for the zero-point offset - TODO confirm).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Pick the output zero point, mirroring the input logic above.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate the scale as:
    #   scale = a * 2^out_type_width / 2^in_type_width
    # with a drawn uniformly from [0, 1) for each channel.
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr[i] is an np.int32; shift in Python ints so large shifts
        # cannot overflow fixed-width NumPy arithmetic before the result is
        # stored in the int64 bound arrays.
        bound_shift = int(shift_arr[i]) - 2
        min_shift_value_arr[i] = -1 << bound_shift
        max_shift_value_arr[i] = (1 << bound_shift) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1713
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    The supplied then/else tensors are used only for their shape (the shape
    of ``then_tens``); their data is ignored. ``cond`` becomes a BOOL
    constant condition. Each block holds a single random INT32 constant that
    is the block's output. For the CondIfOutputList{Then,Else}GraphMismatch
    ERROR_IF cases, the corresponding block instead outputs a constant with
    a deliberately different shape.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    out_shape = then_tens.shape
    bad_then = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    bad_else = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if bad_then or bad_else:
        # Perturb every dimension so the block output no longer matches.
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result tensor has the expected (correct) shape.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    block_specs = (
        (then_block, bad_then, then_arr),
        (else_block, bad_else, else_arr),
    )
    for block_name, use_incorrect, block_arr in block_specs:
        self.ser.startBasicBlock(block_name)
        if use_incorrect:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    )
    if not checks_ok:
        return None

    return result_tens
1781
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a, b.

    FLOAT/INT32 inputs use ADD in the THEN block and SUB in the ELSE block.
    INT8/INT16 inputs use LOGICAL_RIGHT_SHIFT and LOGICAL_LEFT_SHIFT. For the
    CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF cases, the
    selected block gets an input or output whose shape differs from ``a``.

    Args:
        op: operator test-list entry (uses the "op" key).
        a, b: input tensors for the COND_IF and for both blocks.
        cond: Python value stored in the BOOL condition constant.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF being generated, or None for a positive test.

    Returns:
        The result tensor, or None when the ERROR_IF validators reject the
        generated test.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop name: ``op`` must keep referring to the operator
    # test-list entry, which is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1863
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test with an iteration counter and an accumulator.

    The loop carries three tensors: an INT32 counter initialised to
    ``iter_val``, the input ``a``, and an accumulator initialised to zeros
    with ``a``'s shape. The COND block computes ``counter > 0``. The BODY
    block computes ``acc + a`` and ``counter - 1`` and passes ``a`` through.
    For the InputList*/CondGraphOutputNotMatchingBool ERROR_IF cases, the
    relevant block inputs or outputs are given mismatched shapes or types.

    Returns:
        The accumulator output tensor, or None when the ERROR_IF validators
        reject the generated test.
    """
    # Loop counter placeholder; named iter_tens to avoid shadowing the
    # built-in iter().
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        # A rank-0 counter has no dims to perturb, so add one instead.
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_tens = self.ser.addOutput(
            [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
        )
    else:
        cond_tens = self.ser.addOutput([], DType.BOOL)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return acc_out
1975
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001976 def create_filter_lists(
1977 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
1978 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01001979 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1980 default_test_rank_range = range(1, 5)
1981 if not shapeFilter:
1982 shapeFilter = [None]
1983
1984 # Calculate the filters based on what is requested and what the operator allows
1985 rmin, rmax = op["rank"]
1986 if rankFilter is not None:
1987 cleanRankFilter = []
1988 # Ensure rankFilter values are allowed by operator
1989 for rank in rankFilter:
1990 if rank >= rmin and rank <= rmax:
1991 cleanRankFilter.append(rank)
1992 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01001993 # Ensure default behaviour is bounded by default range or by operator,
1994 # whichever is the smaller range of ranks.
1995 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001996 cleanRankFilter = (
1997 opRankRange
1998 if len(opRankRange) <= len(default_test_rank_range)
1999 else default_test_rank_range
2000 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002001 else:
2002 cleanRankFilter = range(rmin, rmax + 1)
2003
2004 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002005
Matthew Haddon1c00b712021-10-01 15:51:03 +01002006 if dtypeFilter is not None:
2007 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002008 # Create list of operator dtypes filtered by requested dtypes
2009 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002010 if dtype in dtypeFilter or (
2011 isinstance(dtype, list) and dtype[0] in dtypeFilter
2012 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002013 cleanDtypeFilter.append(dtype)
2014 else:
2015 cleanDtypeFilter = dtypes
2016
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002017 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002018 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002019 "shapeFilter": shapeFilter,
2020 "rankFilter": cleanRankFilter,
2021 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002022 }
2023 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002024 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002025 if validator is not None:
2026 validator_info = validator(check=False, op=op)
2027 else:
2028 return None
2029
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002030 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002031
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002032 # Set parameters as required
2033 if error_arguments["rank"] is not None:
2034 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002035 else:
2036 rankFilter = cleanRankFilter
2037
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002038 if error_arguments["dtype"] is not None:
2039 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002040 else:
2041 dtypeFilter = cleanDtypeFilter
2042
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002043 if error_arguments["shape"] is not None:
2044 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002045 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002046 shapeFilter = shapeFilter[
2047 :2
2048 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002049
2050 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002051 "shapeFilter": shapeFilter,
2052 "rankFilter": rankFilter,
2053 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002054 }
2055 return filterDict
2056
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptors for one operator.

    Re-seeds ``self.rng`` from ``self.random_seed`` before generating, then
    walks every (rank, dtype, shape) combination returned by
    ``self.create_filter_lists`` and every argument set produced by the
    op's argument generator.

    Args:
        opName: key of the operator in ``self.TOSA_OP_LIST``.
        shapeFilter: list of target shapes; a ``None`` entry forces no
            specific shape. Defaults to ``[None]``.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: ``"positive"`` for normal tests, ``"negative"`` for
            ERROR_IF tests (one pass per entry of the op's
            ``error_if_validators``).

    Returns:
        List of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples. For positive tests, any test flagged by at least one of the
        op's ``invalid_test_validators`` is removed.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # None sentinel instead of a shared mutable list default; same effective default
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    # Only the tensor-shape and argument generators are needed here
    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Each test is a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        # Single pass without an ERROR_IF validator
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (argStr, buildArgs) tuples:
                    # the test-name suffix and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped when ANY validator flags it
        # (previously the flag was reset per validator, so only the last
        # validator's verdict counted).
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList

def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the operator graph for one test and serialize it if valid.

    Args:
        opName: key of the operator in ``self.TOSA_OP_LIST``.
        testStr: test name; combined with ``opName`` by
            ``self.createSerializer`` to locate the output.
        dtype_or_dtypeList: a single dtype applied to every operand, or a
            list giving one dtype per operand.
        error_name: ERROR_IF error being tested, or ``None``.
        shapeList: operand shapes from the op's tensor generator.
        testArgs: extra positional arguments for the op's build function.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        AssertionError: if the shape or dtype list length differs from the
            op's operand count (CONCAT is exempt; its dtype list is sized
            from ``shapeList`` instead).
        TypeError: re-raised from the build function after printing context.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    # Only the build function and tensor-values generator are needed here
    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    is_concat = op["op"] == Op.CONCAT

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif is_concat:
        # CONCAT: one dtype per input shape rather than per declared operand
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)

    try:
        build_args = [self, op, *tens, *testArgs]
        build_kwargs = {}
        if error_if_validators is None:
            # Without error_if_validators, qinfo is the final positional argument
            if qinfo is not None:
                build_args.append(qinfo)
        else:
            build_kwargs["validator_fcns"] = error_if_validators
            build_kwargs["error_name"] = error_name
            if qinfo is not None:
                build_kwargs["qinfo"] = qinfo
        resultName = build_fcn(*build_args, **build_kwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The build function returned a falsy result: the test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")

def createDynamicOpLists(self):
    """Expand the templated convolution entries of ``self.TOSA_OP_LIST``.

    Each ``<base>_TEMPLATE`` entry is shallow-copied once per kernel size
    into a concrete op named ``<base>_<k0>x<k1>`` (``x<k2>`` for 3D), with
    ``filter`` set to the kernel and ``template`` set to False. Afterwards
    every entry whose ``template`` flag is truthy is deleted.

    TOSA_OP_LIST is a class-level dict, so the expansion is shared across
    instances. Templates already expanded and deleted by an earlier
    instance are skipped rather than raising KeyError.
    """
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def _instantiate(base, template, kernel):
        # Add one concrete op for ``kernel`` derived from ``template``
        testName = "{}_{}".format(base, "x".join(str(d) for d in kernel))
        entry = template.copy()
        entry["filter"] = kernel
        entry["template"] = False
        self.TOSA_OP_LIST[testName] = entry

    templates_2d = [
        (base, self.TOSA_OP_LIST.get(base + "_TEMPLATE"))
        for base in ("conv2d", "depthwise_conv2d", "transpose_conv2d")
    ]
    # Kernel-major order (all 2D ops for one kernel, then the next kernel)
    # keeps the original insertion order of the generated ops
    for kernel in KERNELS_2D:
        for base, template in templates_2d:
            if template is not None:
                _instantiate(base, template, kernel)

    conv3d_template = self.TOSA_OP_LIST.get("conv3d_TEMPLATE")
    if conv3d_template is not None:
        for kernel in KERNELS_3D:
            _instantiate("conv3d", conv3d_template, kernel)

    # Delete templates in a second pass: a dict must not change size while
    # it is being iterated
    templateKeys = [
        name for name, entry in self.TOSA_OP_LIST.items() if entry.get("template")
    ]
    for name in templateKeys:
        del self.TOSA_OP_LIST[name]

def initOpListDefaults(self):
    """Validate every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must carry a two-element ``operands`` value, a four-element
    ``build_fcn`` value, a ``types`` field and an ``op`` field. A missing or
    malformed field raises Exception naming the op. Entries without a
    ``rank`` receive ``self.DEFAULT_RANK_RANGE``.
    """
    for opName, entry in self.TOSA_OP_LIST.items():
        # Required fields; the unpacking doubles as a length/iterability check
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(opName)
            )

        try:
            _build, _tensorGen, _valuesGen, _argGen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Optional fields
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)

# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op enum value
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, defaults to DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build function, TosaTensorGen function,
#     TosaTensorValuesGen function, TosaArgGen function or None)
# 'types': array of datatypes to be tested
# 'qgen': optional TosaQuantGen function that produces the qinfo passed to
#     the build function
# 'error_if_validators': optional tuple of TosaErrorValidator functions used
#     to generate negative (ERROR_IF) tests
# 'invalid_test_validators': optional tuple of TosaInvalidValidator functions;
#     positive tests they flag are removed
# 'template': optional; True marks an entry expanded into concrete ops by
#     createDynamicOpLists and then removed
# 'filter': kernel size; set on each op created from a template
TYPE_FP = [DType.FLOAT]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [DType.FLOAT, DType.INT32]
TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
TYPE_FI16 = [DType.FLOAT, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

# List entries give one dtype per operand (the convolution ops using this
# have three operands); a bare dtype is applied to every operand
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    DType.FLOAT,
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)

2372 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002373 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002374 "argmax": {
2375 "op": Op.ARGMAX,
2376 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002377 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002378 "build_fcn": (
2379 build_argmax,
2380 TosaTensorGen.tgBasic,
2381 TosaTensorValuesGen.tvgDefault,
2382 TosaArgGen.agAxis,
2383 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002384 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002385 "error_if_validators": (
2386 TosaErrorValidator.evAxisSmallerZero,
2387 TosaErrorValidator.evAxisLargerRank,
2388 TosaErrorValidator.evArgmaxOutputRankMismatch,
2389 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2390 TosaErrorValidator.evWrongRank,
2391 TosaErrorValidator.evWrongInputType,
2392 TosaErrorValidator.evWrongOutputType,
2393 TosaErrorValidator.evWrongInputList,
2394 TosaErrorValidator.evWrongOutputList,
2395 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002396 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002397 "avg_pool2d": {
2398 "op": Op.AVG_POOL2D,
2399 "operands": (1, 0),
2400 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002401 "build_fcn": (
2402 build_pool2d,
2403 TosaTensorGen.tgNHWC,
2404 TosaTensorValuesGen.tvgDefault,
2405 TosaArgGen.agPooling,
2406 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002407 "qgen": TosaQuantGen.qgUnary,
2408 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002409 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002410 "error_if_validators": (
2411 TosaErrorValidator.evKernelSmallerOne,
2412 TosaErrorValidator.evStrideSmallerOne,
2413 TosaErrorValidator.evPadSmallerZero,
2414 TosaErrorValidator.evWrongRank,
2415 TosaErrorValidator.evWrongInputType,
2416 TosaErrorValidator.evWrongOutputType,
2417 TosaErrorValidator.evWrongInputList,
2418 TosaErrorValidator.evWrongOutputList,
2419 TosaErrorValidator.evInputZeroPointNotZero,
2420 TosaErrorValidator.evOutputZeroPointNotZero,
2421 TosaErrorValidator.evPadLargerEqualKernel,
2422 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002423 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002424 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002425 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002426 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002427 "conv2d_TEMPLATE": {
2428 "op": Op.CONV2D,
2429 "operands": (1, 2),
2430 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002431 "build_fcn": (
2432 build_conv2d,
2433 TosaTensorGen.tgConv2D,
2434 TosaTensorValuesGen.tvgDefault,
2435 TosaArgGen.agConv,
2436 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002437 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002438 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002439 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2440 "error_if_validators": (
2441 TosaErrorValidator.evWrongInputType,
2442 TosaErrorValidator.evWrongOutputType,
2443 TosaErrorValidator.evWrongInputList,
2444 TosaErrorValidator.evWrongOutputList,
2445 TosaErrorValidator.evInputZeroPointNotZero,
2446 TosaErrorValidator.evWeightZeroPointNotZero,
2447 TosaErrorValidator.evPadSmallerZero,
2448 TosaErrorValidator.evStrideSmallerOne,
2449 TosaErrorValidator.evDilationSmallerOne,
2450 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002451 TosaErrorValidator.evConvOutputShapeMismatch,
2452 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002453 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002454 "template": True,
2455 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002456 # Templated operator. Filled in by createDynamicOpLists
2457 "conv3d_TEMPLATE": {
2458 "op": Op.CONV3D,
2459 "operands": (1, 2),
2460 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002461 "build_fcn": (
2462 build_conv3d,
2463 TosaTensorGen.tgConv3D,
2464 TosaTensorValuesGen.tvgDefault,
2465 TosaArgGen.agConv,
2466 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002467 "qgen": TosaQuantGen.qgConv,
2468 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002469 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2470 "error_if_validators": (
2471 TosaErrorValidator.evWrongInputType,
2472 TosaErrorValidator.evWrongOutputType,
2473 TosaErrorValidator.evWrongInputList,
2474 TosaErrorValidator.evWrongOutputList,
2475 TosaErrorValidator.evInputZeroPointNotZero,
2476 TosaErrorValidator.evWeightZeroPointNotZero,
2477 TosaErrorValidator.evPadSmallerZero,
2478 TosaErrorValidator.evStrideSmallerOne,
2479 TosaErrorValidator.evDilationSmallerOne,
2480 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002481 TosaErrorValidator.evConvOutputShapeMismatch,
2482 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002483 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002484 "template": True,
2485 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002486 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002487 "depthwise_conv2d_TEMPLATE": {
2488 "op": Op.DEPTHWISE_CONV2D,
2489 "operands": (1, 2),
2490 "filter": [1, 1],
2491 "rank": (4, 4),
2492 "build_fcn": (
2493 build_depthwise_conv2d,
2494 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002495 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002496 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002497 ),
2498 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002499 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002500 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2501 "error_if_validators": (
2502 TosaErrorValidator.evWrongInputType,
2503 TosaErrorValidator.evWrongOutputType,
2504 TosaErrorValidator.evWrongInputList,
2505 TosaErrorValidator.evWrongOutputList,
2506 TosaErrorValidator.evInputZeroPointNotZero,
2507 TosaErrorValidator.evWeightZeroPointNotZero,
2508 TosaErrorValidator.evPadSmallerZero,
2509 TosaErrorValidator.evStrideSmallerOne,
2510 TosaErrorValidator.evDilationSmallerOne,
2511 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002512 TosaErrorValidator.evConvOutputShapeMismatch,
2513 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002514 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002515 "template": True,
2516 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002517 "fully_connected": {
2518 "op": Op.FULLY_CONNECTED,
2519 "operands": (1, 2),
2520 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002521 "build_fcn": (
2522 build_fully_connected,
2523 TosaTensorGen.tgFullyConnected,
2524 TosaTensorValuesGen.tvgDefault,
2525 None,
2526 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002527 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002528 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002529 "error_if_validators": (
2530 TosaErrorValidator.evInputZeroPointNotZero,
2531 TosaErrorValidator.evWeightZeroPointNotZero,
2532 TosaErrorValidator.evWrongRank,
2533 TosaErrorValidator.evWrongInputType,
2534 TosaErrorValidator.evWrongOutputType,
2535 TosaErrorValidator.evWrongInputList,
2536 TosaErrorValidator.evWrongOutputList,
2537 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002538 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002539 "matmul": {
2540 "op": Op.MATMUL,
2541 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002542 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002543 "build_fcn": (
2544 build_matmul,
2545 TosaTensorGen.tgMatmul,
2546 TosaTensorValuesGen.tvgDefault,
2547 None,
2548 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002549 "qgen": TosaQuantGen.qgMatmul,
2550 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002551 "error_if_validators": (
2552 TosaErrorValidator.evInputZeroPointNotZero,
2553 TosaErrorValidator.evWrongRank,
2554 TosaErrorValidator.evWrongInputType,
2555 TosaErrorValidator.evWrongOutputType,
2556 TosaErrorValidator.evWrongInputList,
2557 TosaErrorValidator.evWrongOutputList,
2558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002559 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002560 "max_pool2d": {
2561 "op": Op.MAX_POOL2D,
2562 "operands": (1, 0),
2563 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002564 "build_fcn": (
2565 build_pool2d,
2566 TosaTensorGen.tgNHWC,
2567 TosaTensorValuesGen.tvgDefault,
2568 TosaArgGen.agPooling,
2569 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002570 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002571 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002572 "error_if_validators": (
2573 TosaErrorValidator.evKernelSmallerOne,
2574 TosaErrorValidator.evStrideSmallerOne,
2575 TosaErrorValidator.evPadSmallerZero,
2576 TosaErrorValidator.evWrongRank,
2577 TosaErrorValidator.evWrongInputType,
2578 TosaErrorValidator.evWrongOutputType,
2579 TosaErrorValidator.evWrongInputList,
2580 TosaErrorValidator.evWrongOutputList,
2581 TosaErrorValidator.evPadLargerEqualKernel,
2582 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002583 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002584 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002585 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002586 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002587 "transpose_conv2d_TEMPLATE": {
2588 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002589 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002590 "rank": (4, 4),
2591 "build_fcn": (
2592 build_transpose_conv2d,
2593 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002594 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002595 TosaArgGen.agTransposeConv2D,
2596 ),
2597 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002598 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002599 "invalid_test_validators": (
2600 TosaInvalidValidator.ivHeightWidthInvalid,
2601 TosaInvalidValidator.ivNonPositiveOutputShape,
2602 ),
2603 "error_if_validators": (
2604 TosaErrorValidator.evWrongInputType,
2605 TosaErrorValidator.evWrongOutputType,
2606 TosaErrorValidator.evWrongInputList,
2607 TosaErrorValidator.evWrongOutputList,
2608 TosaErrorValidator.evInputZeroPointNotZero,
2609 TosaErrorValidator.evWeightZeroPointNotZero,
2610 TosaErrorValidator.evPadSmallerZero,
2611 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002612 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002613 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002614 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002615 "template": True,
2616 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002617 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002618 "clamp": {
2619 "op": Op.CLAMP,
2620 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002621 "build_fcn": (
2622 build_clamp,
2623 TosaTensorGen.tgBasic,
2624 TosaTensorValuesGen.tvgDefault,
2625 None,
2626 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002627 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002628 "error_if_validators": (
2629 TosaErrorValidator.evMaxSmallerMin,
2630 TosaErrorValidator.evWrongInputType,
2631 TosaErrorValidator.evWrongOutputType,
2632 TosaErrorValidator.evWrongInputList,
2633 TosaErrorValidator.evWrongOutputList,
2634 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002635 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002636 "sigmoid": {
2637 "op": Op.SIGMOID,
2638 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002639 "build_fcn": (
2640 build_sigmoid,
2641 TosaTensorGen.tgBasic,
2642 TosaTensorValuesGen.tvgDefault,
2643 None,
2644 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002645 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002646 "error_if_validators": (
2647 TosaErrorValidator.evWrongInputType,
2648 TosaErrorValidator.evWrongOutputType,
2649 TosaErrorValidator.evWrongInputList,
2650 TosaErrorValidator.evWrongOutputList,
2651 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002652 },
2653 "tanh": {
2654 "op": Op.TANH,
2655 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002656 "build_fcn": (
2657 build_tanh,
2658 TosaTensorGen.tgBasic,
2659 TosaTensorValuesGen.tvgDefault,
2660 None,
2661 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002662 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002663 "error_if_validators": (
2664 TosaErrorValidator.evWrongInputType,
2665 TosaErrorValidator.evWrongOutputType,
2666 TosaErrorValidator.evWrongInputList,
2667 TosaErrorValidator.evWrongOutputList,
2668 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002670 # Elementwise Binary Operators
2671 "add": {
2672 "op": Op.ADD,
2673 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002674 "build_fcn": (
2675 build_binary_broadcast,
2676 TosaTensorGen.tgBroadcastFuzz,
2677 TosaTensorValuesGen.tvgAddSub,
2678 None,
2679 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002680 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002681 "error_if_validators": (
2682 TosaErrorValidator.evRankMismatch,
2683 TosaErrorValidator.evWrongInputType,
2684 TosaErrorValidator.evWrongOutputType,
2685 TosaErrorValidator.evWrongInputList,
2686 TosaErrorValidator.evWrongOutputList,
2687 TosaErrorValidator.evDimensionMismatch,
2688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002689 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002690 "arithmetic_right_shift": {
2691 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2692 "operands": (2, 0),
2693 "build_fcn": (
2694 build_arithmetic_right_shift,
2695 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002696 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002697 TosaArgGen.agArithmeticRightShift,
2698 ),
2699 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002700 "error_if_validators": (
2701 TosaErrorValidator.evRankMismatch,
2702 TosaErrorValidator.evWrongInputType,
2703 TosaErrorValidator.evWrongOutputType,
2704 TosaErrorValidator.evWrongInputList,
2705 TosaErrorValidator.evWrongOutputList,
2706 TosaErrorValidator.evDimensionMismatch,
2707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002708 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002709 "bitwise_and": {
2710 "op": Op.BITWISE_AND,
2711 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002712 "build_fcn": (
2713 build_binary_broadcast,
2714 TosaTensorGen.tgBroadcastFuzz,
2715 TosaTensorValuesGen.tvgDefault,
2716 None,
2717 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002718 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002719 "error_if_validators": (
2720 TosaErrorValidator.evRankMismatch,
2721 TosaErrorValidator.evWrongInputType,
2722 TosaErrorValidator.evWrongOutputType,
2723 TosaErrorValidator.evWrongInputList,
2724 TosaErrorValidator.evWrongOutputList,
2725 TosaErrorValidator.evDimensionMismatch,
2726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002727 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002728 "bitwise_or": {
2729 "op": Op.BITWISE_OR,
2730 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002731 "build_fcn": (
2732 build_binary_broadcast,
2733 TosaTensorGen.tgBroadcastFuzz,
2734 TosaTensorValuesGen.tvgDefault,
2735 None,
2736 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002737 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002738 "error_if_validators": (
2739 TosaErrorValidator.evRankMismatch,
2740 TosaErrorValidator.evWrongInputType,
2741 TosaErrorValidator.evWrongOutputType,
2742 TosaErrorValidator.evWrongInputList,
2743 TosaErrorValidator.evWrongOutputList,
2744 TosaErrorValidator.evDimensionMismatch,
2745 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002746 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002747 "bitwise_xor": {
2748 "op": Op.BITWISE_XOR,
2749 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002750 "build_fcn": (
2751 build_binary_broadcast,
2752 TosaTensorGen.tgBroadcastFuzz,
2753 TosaTensorValuesGen.tvgDefault,
2754 None,
2755 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002756 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002757 "error_if_validators": (
2758 TosaErrorValidator.evRankMismatch,
2759 TosaErrorValidator.evWrongInputType,
2760 TosaErrorValidator.evWrongOutputType,
2761 TosaErrorValidator.evWrongInputList,
2762 TosaErrorValidator.evWrongOutputList,
2763 TosaErrorValidator.evDimensionMismatch,
2764 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002765 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002766 "intdiv": {
2767 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002768 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002769 "build_fcn": (
2770 build_binary_broadcast,
2771 TosaTensorGen.tgBroadcastFuzz,
2772 TosaTensorValuesGen.tvgIntDiv,
2773 None,
2774 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002775 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002776 "error_if_validators": (
2777 TosaErrorValidator.evRankMismatch,
2778 TosaErrorValidator.evWrongInputType,
2779 TosaErrorValidator.evWrongOutputType,
2780 TosaErrorValidator.evWrongInputList,
2781 TosaErrorValidator.evWrongOutputList,
2782 TosaErrorValidator.evDimensionMismatch,
2783 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002784 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002785 "logical_and": {
2786 "op": Op.LOGICAL_AND,
2787 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002788 "build_fcn": (
2789 build_binary_broadcast,
2790 TosaTensorGen.tgBroadcastFuzz,
2791 TosaTensorValuesGen.tvgDefault,
2792 None,
2793 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002794 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002795 "error_if_validators": (
2796 TosaErrorValidator.evRankMismatch,
2797 TosaErrorValidator.evWrongInputType,
2798 TosaErrorValidator.evWrongOutputType,
2799 TosaErrorValidator.evWrongInputList,
2800 TosaErrorValidator.evWrongOutputList,
2801 TosaErrorValidator.evDimensionMismatch,
2802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002803 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002804 "logical_left_shift": {
2805 "op": Op.LOGICAL_LEFT_SHIFT,
2806 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002807 "build_fcn": (
2808 build_binary_broadcast,
2809 TosaTensorGen.tgBroadcastFuzz,
2810 TosaTensorValuesGen.tvgLogicalShift,
2811 None,
2812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002813 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002814 "error_if_validators": (
2815 TosaErrorValidator.evRankMismatch,
2816 TosaErrorValidator.evWrongInputType,
2817 TosaErrorValidator.evWrongOutputType,
2818 TosaErrorValidator.evWrongInputList,
2819 TosaErrorValidator.evWrongOutputList,
2820 TosaErrorValidator.evDimensionMismatch,
2821 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002822 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002823 "logical_right_shift": {
2824 "op": Op.LOGICAL_RIGHT_SHIFT,
2825 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002826 "build_fcn": (
2827 build_binary_broadcast,
2828 TosaTensorGen.tgBroadcastFuzz,
2829 TosaTensorValuesGen.tvgLogicalShift,
2830 None,
2831 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002832 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002833 "error_if_validators": (
2834 TosaErrorValidator.evRankMismatch,
2835 TosaErrorValidator.evWrongInputType,
2836 TosaErrorValidator.evWrongOutputType,
2837 TosaErrorValidator.evWrongInputList,
2838 TosaErrorValidator.evWrongOutputList,
2839 TosaErrorValidator.evDimensionMismatch,
2840 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002841 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002842 "logical_or": {
2843 "op": Op.LOGICAL_OR,
2844 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002845 "build_fcn": (
2846 build_binary_broadcast,
2847 TosaTensorGen.tgBroadcastFuzz,
2848 TosaTensorValuesGen.tvgDefault,
2849 None,
2850 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002851 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002852 "error_if_validators": (
2853 TosaErrorValidator.evRankMismatch,
2854 TosaErrorValidator.evWrongInputType,
2855 TosaErrorValidator.evWrongOutputType,
2856 TosaErrorValidator.evWrongInputList,
2857 TosaErrorValidator.evWrongOutputList,
2858 TosaErrorValidator.evDimensionMismatch,
2859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002860 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002861 "logical_xor": {
2862 "op": Op.LOGICAL_XOR,
2863 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002864 "build_fcn": (
2865 build_binary_broadcast,
2866 TosaTensorGen.tgBroadcastFuzz,
2867 TosaTensorValuesGen.tvgDefault,
2868 None,
2869 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002870 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002871 "error_if_validators": (
2872 TosaErrorValidator.evRankMismatch,
2873 TosaErrorValidator.evWrongInputType,
2874 TosaErrorValidator.evWrongOutputType,
2875 TosaErrorValidator.evWrongInputList,
2876 TosaErrorValidator.evWrongOutputList,
2877 TosaErrorValidator.evDimensionMismatch,
2878 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002879 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002880 "maximum": {
2881 "op": Op.MAXIMUM,
2882 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002883 "build_fcn": (
2884 build_binary_broadcast,
2885 TosaTensorGen.tgBroadcastFuzz,
2886 TosaTensorValuesGen.tvgDefault,
2887 None,
2888 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002889 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002890 "error_if_validators": (
2891 TosaErrorValidator.evRankMismatch,
2892 TosaErrorValidator.evWrongInputType,
2893 TosaErrorValidator.evWrongOutputType,
2894 TosaErrorValidator.evWrongInputList,
2895 TosaErrorValidator.evWrongOutputList,
2896 TosaErrorValidator.evDimensionMismatch,
2897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002898 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002899 "minimum": {
2900 "op": Op.MINIMUM,
2901 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002902 "build_fcn": (
2903 build_binary_broadcast,
2904 TosaTensorGen.tgBroadcastFuzz,
2905 TosaTensorValuesGen.tvgDefault,
2906 None,
2907 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002908 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002909 "error_if_validators": (
2910 TosaErrorValidator.evRankMismatch,
2911 TosaErrorValidator.evWrongInputType,
2912 TosaErrorValidator.evWrongOutputType,
2913 TosaErrorValidator.evWrongInputList,
2914 TosaErrorValidator.evWrongOutputList,
2915 TosaErrorValidator.evDimensionMismatch,
2916 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002917 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002918 "mul": {
2919 "op": Op.MUL,
2920 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002921 "build_fcn": (
2922 build_mul,
2923 TosaTensorGen.tgBroadcastFuzz,
2924 TosaTensorValuesGen.tvgMul,
2925 TosaArgGen.agMul,
2926 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002927 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002928 "error_if_validators": (
2929 TosaErrorValidator.evWrongInputType,
2930 TosaErrorValidator.evWrongOutputType,
2931 TosaErrorValidator.evWrongInputList,
2932 TosaErrorValidator.evWrongOutputList,
2933 TosaErrorValidator.evRankMismatch,
2934 TosaErrorValidator.evDimensionMismatch,
2935 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002936 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002937 "pow": {
2938 "op": Op.POW,
2939 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002940 "build_fcn": (
2941 build_binary_broadcast,
2942 TosaTensorGen.tgBroadcastFuzz,
2943 TosaTensorValuesGen.tvgDefault,
2944 None,
2945 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002946 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002947 "error_if_validators": (
2948 TosaErrorValidator.evRankMismatch,
2949 TosaErrorValidator.evWrongInputType,
2950 TosaErrorValidator.evWrongOutputType,
2951 TosaErrorValidator.evWrongInputList,
2952 TosaErrorValidator.evWrongOutputList,
2953 TosaErrorValidator.evDimensionMismatch,
2954 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002955 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002956 "sub": {
2957 "op": Op.SUB,
2958 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002959 "build_fcn": (
2960 build_binary_broadcast,
2961 TosaTensorGen.tgBroadcastFuzz,
2962 TosaTensorValuesGen.tvgAddSub,
2963 None,
2964 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002965 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002966 "error_if_validators": (
2967 TosaErrorValidator.evRankMismatch,
2968 TosaErrorValidator.evWrongInputType,
2969 TosaErrorValidator.evWrongOutputType,
2970 TosaErrorValidator.evWrongInputList,
2971 TosaErrorValidator.evWrongOutputList,
2972 TosaErrorValidator.evDimensionMismatch,
2973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002975 "table": {
2976 "op": Op.TABLE,
2977 # Use the automatic generation functions to create the input array
2978 # but create the table tensor in the build function, as it may be
2979 # a different type from the input
2980 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002981 "build_fcn": (
2982 build_table,
2983 TosaTensorGen.tgBasic,
2984 TosaTensorValuesGen.tvgDefault,
2985 TosaArgGen.agTable,
2986 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002987 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002988 "error_if_validators": (
2989 TosaErrorValidator.evWrongInputType,
2990 TosaErrorValidator.evWrongOutputType,
2991 TosaErrorValidator.evWrongInputList,
2992 TosaErrorValidator.evWrongOutputList,
2993 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002994 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002995 # Elementwise Unary operators
2996 "abs": {
2997 "op": Op.ABS,
2998 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002999 "build_fcn": (
3000 build_unary,
3001 TosaTensorGen.tgBasic,
3002 TosaTensorValuesGen.tvgDefault,
3003 None,
3004 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003005 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003006 "error_if_validators": (
3007 TosaErrorValidator.evWrongInputType,
3008 TosaErrorValidator.evWrongOutputType,
3009 TosaErrorValidator.evWrongInputList,
3010 TosaErrorValidator.evWrongOutputList,
3011 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003013 "bitwise_not": {
3014 "op": Op.BITWISE_NOT,
3015 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003016 "build_fcn": (
3017 build_unary,
3018 TosaTensorGen.tgBasic,
3019 TosaTensorValuesGen.tvgDefault,
3020 None,
3021 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003022 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003023 "error_if_validators": (
3024 TosaErrorValidator.evWrongInputType,
3025 TosaErrorValidator.evWrongOutputType,
3026 TosaErrorValidator.evWrongInputList,
3027 TosaErrorValidator.evWrongOutputList,
3028 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003030 "ceil": {
3031 "op": Op.CEIL,
3032 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003033 "build_fcn": (
3034 build_unary,
3035 TosaTensorGen.tgBasic,
3036 TosaTensorValuesGen.tvgDefault,
3037 None,
3038 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003039 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003040 "error_if_validators": (
3041 TosaErrorValidator.evWrongInputType,
3042 TosaErrorValidator.evWrongOutputType,
3043 TosaErrorValidator.evWrongInputList,
3044 TosaErrorValidator.evWrongOutputList,
3045 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003046 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003047 "clz": {
3048 "op": Op.CLZ,
3049 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003050 "build_fcn": (
3051 build_unary,
3052 TosaTensorGen.tgBasic,
3053 TosaTensorValuesGen.tvgDefault,
3054 None,
3055 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003056 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003057 "error_if_validators": (
3058 TosaErrorValidator.evWrongInputType,
3059 TosaErrorValidator.evWrongOutputType,
3060 TosaErrorValidator.evWrongInputList,
3061 TosaErrorValidator.evWrongOutputList,
3062 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003063 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003064 "exp": {
3065 "op": Op.EXP,
3066 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003067 "build_fcn": (
3068 build_unary,
3069 TosaTensorGen.tgBasic,
3070 TosaTensorValuesGen.tvgDefault,
3071 None,
3072 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003073 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003074 "error_if_validators": (
3075 TosaErrorValidator.evWrongInputType,
3076 TosaErrorValidator.evWrongOutputType,
3077 TosaErrorValidator.evWrongInputList,
3078 TosaErrorValidator.evWrongOutputList,
3079 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003081 "floor": {
3082 "op": Op.FLOOR,
3083 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003084 "build_fcn": (
3085 build_unary,
3086 TosaTensorGen.tgBasic,
3087 TosaTensorValuesGen.tvgDefault,
3088 None,
3089 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003090 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003091 "error_if_validators": (
3092 TosaErrorValidator.evWrongInputType,
3093 TosaErrorValidator.evWrongOutputType,
3094 TosaErrorValidator.evWrongInputList,
3095 TosaErrorValidator.evWrongOutputList,
3096 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003097 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003098 "log": {
3099 "op": Op.LOG,
3100 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003101 "build_fcn": (
3102 build_unary,
3103 TosaTensorGen.tgBasic,
3104 TosaTensorValuesGen.tvgDefault,
3105 None,
3106 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003107 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 "error_if_validators": (
3109 TosaErrorValidator.evWrongInputType,
3110 TosaErrorValidator.evWrongOutputType,
3111 TosaErrorValidator.evWrongInputList,
3112 TosaErrorValidator.evWrongOutputList,
3113 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003114 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003115 "logical_not": {
3116 "op": Op.LOGICAL_NOT,
3117 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003118 "build_fcn": (
3119 build_unary,
3120 TosaTensorGen.tgBasic,
3121 TosaTensorValuesGen.tvgDefault,
3122 None,
3123 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003124 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003125 "error_if_validators": (
3126 TosaErrorValidator.evWrongInputType,
3127 TosaErrorValidator.evWrongOutputType,
3128 TosaErrorValidator.evWrongInputList,
3129 TosaErrorValidator.evWrongOutputList,
3130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 "negate": {
3133 "op": Op.NEGATE,
3134 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003135 "build_fcn": (
3136 build_unary,
3137 TosaTensorGen.tgBasic,
3138 TosaTensorValuesGen.tvgNegate,
3139 None,
3140 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 "qgen": TosaQuantGen.qgUnary,
3142 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003143 "error_if_validators": (
3144 TosaErrorValidator.evInputZeroPointNotZero,
3145 TosaErrorValidator.evOutputZeroPointNotZero,
3146 TosaErrorValidator.evWrongInputType,
3147 TosaErrorValidator.evWrongOutputType,
3148 TosaErrorValidator.evWrongInputList,
3149 TosaErrorValidator.evWrongOutputList,
3150 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003152 "reciprocal": {
3153 "op": Op.RECIPROCAL,
3154 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003155 "build_fcn": (
3156 build_unary,
3157 TosaTensorGen.tgBasic,
3158 TosaTensorValuesGen.tvgDefault,
3159 None,
3160 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003161 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003162 "error_if_validators": (
3163 TosaErrorValidator.evWrongInputType,
3164 TosaErrorValidator.evWrongOutputType,
3165 TosaErrorValidator.evWrongInputList,
3166 TosaErrorValidator.evWrongOutputList,
3167 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003169 "rsqrt": {
3170 "op": Op.RSQRT,
3171 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003172 "build_fcn": (
3173 build_unary,
3174 TosaTensorGen.tgBasic,
3175 TosaTensorValuesGen.tvgDefault,
3176 None,
3177 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003178 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003179 "error_if_validators": (
3180 TosaErrorValidator.evWrongInputType,
3181 TosaErrorValidator.evWrongOutputType,
3182 TosaErrorValidator.evWrongInputList,
3183 TosaErrorValidator.evWrongOutputList,
3184 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003186 # Elementwise Ternary operators
3187 "select": {
3188 "op": Op.SELECT,
3189 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190 "build_fcn": (
3191 build_select,
3192 TosaTensorGen.tgBroadcastFuzz,
3193 TosaTensorValuesGen.tvgSelect,
3194 None,
3195 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003196 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003197 "error_if_validators": (
3198 TosaErrorValidator.evRankMismatch,
3199 TosaErrorValidator.evWrongInputType,
3200 TosaErrorValidator.evWrongOutputType,
3201 TosaErrorValidator.evWrongInputList,
3202 TosaErrorValidator.evWrongOutputList,
3203 TosaErrorValidator.evDimensionMismatch,
3204 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003206 # Comparison operators
3207 "equal": {
3208 "op": Op.EQUAL,
3209 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003210 "build_fcn": (
3211 build_comparison,
3212 TosaTensorGen.tgBroadcastFuzz,
3213 TosaTensorValuesGen.tvgEqual,
3214 None,
3215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003216 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003217 "error_if_validators": (
3218 TosaErrorValidator.evRankMismatch,
3219 TosaErrorValidator.evWrongInputType,
3220 TosaErrorValidator.evWrongOutputType,
3221 TosaErrorValidator.evWrongInputList,
3222 TosaErrorValidator.evWrongOutputList,
3223 TosaErrorValidator.evDimensionMismatch,
3224 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003225 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003226 "greater_equal": {
3227 "op": Op.GREATER_EQUAL,
3228 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003229 "build_fcn": (
3230 build_comparison,
3231 TosaTensorGen.tgBroadcastFuzz,
3232 TosaTensorValuesGen.tvgDefault,
3233 None,
3234 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003235 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003236 "error_if_validators": (
3237 TosaErrorValidator.evRankMismatch,
3238 TosaErrorValidator.evWrongInputType,
3239 TosaErrorValidator.evWrongOutputType,
3240 TosaErrorValidator.evWrongInputList,
3241 TosaErrorValidator.evWrongOutputList,
3242 TosaErrorValidator.evDimensionMismatch,
3243 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003244 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003245 "greater": {
3246 "op": Op.GREATER,
3247 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003248 "build_fcn": (
3249 build_comparison,
3250 TosaTensorGen.tgBroadcastFuzz,
3251 TosaTensorValuesGen.tvgDefault,
3252 None,
3253 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003254 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003255 "error_if_validators": (
3256 TosaErrorValidator.evRankMismatch,
3257 TosaErrorValidator.evWrongInputType,
3258 TosaErrorValidator.evWrongOutputType,
3259 TosaErrorValidator.evWrongInputList,
3260 TosaErrorValidator.evWrongOutputList,
3261 TosaErrorValidator.evDimensionMismatch,
3262 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003264 # Reduction operators
3265 "reduce_all": {
3266 "op": Op.REDUCE_ALL,
3267 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003268 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003269 "build_fcn": (
3270 build_reduce,
3271 TosaTensorGen.tgBasic,
3272 TosaTensorValuesGen.tvgDefault,
3273 TosaArgGen.agAxis,
3274 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003275 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003276 "error_if_validators": (
3277 TosaErrorValidator.evAxisLargerRank,
3278 TosaErrorValidator.evAxisSmallerZero,
3279 TosaErrorValidator.evShapeOfAxisNotOne,
3280 TosaErrorValidator.evWrongInputType,
3281 TosaErrorValidator.evWrongOutputType,
3282 TosaErrorValidator.evWrongRank,
3283 TosaErrorValidator.evWrongInputList,
3284 TosaErrorValidator.evWrongOutputList,
3285 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 "reduce_any": {
3288 "op": Op.REDUCE_ANY,
3289 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003290 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003291 "build_fcn": (
3292 build_reduce,
3293 TosaTensorGen.tgBasic,
3294 TosaTensorValuesGen.tvgDefault,
3295 TosaArgGen.agAxis,
3296 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003297 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003298 "error_if_validators": (
3299 TosaErrorValidator.evAxisLargerRank,
3300 TosaErrorValidator.evAxisSmallerZero,
3301 TosaErrorValidator.evShapeOfAxisNotOne,
3302 TosaErrorValidator.evWrongInputType,
3303 TosaErrorValidator.evWrongOutputType,
3304 TosaErrorValidator.evWrongRank,
3305 TosaErrorValidator.evWrongInputList,
3306 TosaErrorValidator.evWrongOutputList,
3307 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "reduce_max": {
3310 "op": Op.REDUCE_MAX,
3311 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003312 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003313 "build_fcn": (
3314 build_reduce,
3315 TosaTensorGen.tgBasic,
3316 TosaTensorValuesGen.tvgDefault,
3317 TosaArgGen.agAxis,
3318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003320 "error_if_validators": (
3321 TosaErrorValidator.evAxisLargerRank,
3322 TosaErrorValidator.evAxisSmallerZero,
3323 TosaErrorValidator.evShapeOfAxisNotOne,
3324 TosaErrorValidator.evWrongInputType,
3325 TosaErrorValidator.evWrongOutputType,
3326 TosaErrorValidator.evWrongRank,
3327 TosaErrorValidator.evWrongInputList,
3328 TosaErrorValidator.evWrongOutputList,
3329 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003331 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003332 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003334 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335 "build_fcn": (
3336 build_reduce,
3337 TosaTensorGen.tgBasic,
3338 TosaTensorValuesGen.tvgDefault,
3339 TosaArgGen.agAxis,
3340 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003341 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003342 "error_if_validators": (
3343 TosaErrorValidator.evAxisLargerRank,
3344 TosaErrorValidator.evAxisSmallerZero,
3345 TosaErrorValidator.evShapeOfAxisNotOne,
3346 TosaErrorValidator.evWrongInputType,
3347 TosaErrorValidator.evWrongOutputType,
3348 TosaErrorValidator.evWrongRank,
3349 TosaErrorValidator.evWrongInputList,
3350 TosaErrorValidator.evWrongOutputList,
3351 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 "reduce_product": {
3354 "op": Op.REDUCE_PRODUCT,
3355 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003356 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003357 "build_fcn": (
3358 build_reduce,
3359 TosaTensorGen.tgBasic,
3360 TosaTensorValuesGen.tvgDefault,
3361 TosaArgGen.agAxis,
3362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003364 "error_if_validators": (
3365 TosaErrorValidator.evAxisLargerRank,
3366 TosaErrorValidator.evAxisSmallerZero,
3367 TosaErrorValidator.evShapeOfAxisNotOne,
3368 TosaErrorValidator.evWrongInputType,
3369 TosaErrorValidator.evWrongOutputType,
3370 TosaErrorValidator.evWrongRank,
3371 TosaErrorValidator.evWrongInputList,
3372 TosaErrorValidator.evWrongOutputList,
3373 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003375 "reduce_sum": {
3376 "op": Op.REDUCE_SUM,
3377 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003378 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003379 "build_fcn": (
3380 build_reduce,
3381 TosaTensorGen.tgBasic,
3382 TosaTensorValuesGen.tvgReduceSum,
3383 TosaArgGen.agAxis,
3384 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003386 "error_if_validators": (
3387 TosaErrorValidator.evAxisLargerRank,
3388 TosaErrorValidator.evAxisSmallerZero,
3389 TosaErrorValidator.evShapeOfAxisNotOne,
3390 TosaErrorValidator.evWrongInputType,
3391 TosaErrorValidator.evWrongOutputType,
3392 TosaErrorValidator.evWrongRank,
3393 TosaErrorValidator.evWrongInputList,
3394 TosaErrorValidator.evWrongOutputList,
3395 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003397 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003398 "concat": {
3399 "op": Op.CONCAT,
3400 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003401 "build_fcn": (
3402 build_concat,
3403 TosaTensorGen.tgConcat,
3404 TosaTensorValuesGen.tvgConcat,
3405 TosaArgGen.agAxis,
3406 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003407 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003408 "error_if_validators": (
3409 TosaErrorValidator.evAxisLargerRank,
3410 TosaErrorValidator.evAxisSmallerZero,
3411 TosaErrorValidator.evConcatInputRankMismatch,
3412 TosaErrorValidator.evConcatShapeSumMismatch,
3413 TosaErrorValidator.evConcatInputDimMismatch,
3414 TosaErrorValidator.evWrongInputType,
3415 TosaErrorValidator.evWrongOutputType,
3416 TosaErrorValidator.evWrongOutputList,
3417 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003418 },
3419 "pad": {
3420 "op": Op.PAD,
3421 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003422 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003423 "build_fcn": (
3424 build_pad,
3425 TosaTensorGen.tgBasic,
3426 TosaTensorValuesGen.tvgDefault,
3427 TosaArgGen.agPad,
3428 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003429 "qgen": TosaQuantGen.qgPad,
3430 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003431 "error_if_validators": (
3432 TosaErrorValidator.evWrongInputType,
3433 TosaErrorValidator.evPadSmallerZero,
3434 TosaErrorValidator.evWrongOutputType,
3435 TosaErrorValidator.evWrongInputList,
3436 TosaErrorValidator.evWrongOutputList,
3437 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003438 },
3439 "reshape": {
3440 "op": Op.RESHAPE,
3441 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003442 "build_fcn": (
3443 build_reshape,
3444 TosaTensorGen.tgBasic,
3445 TosaTensorValuesGen.tvgDefault,
3446 TosaArgGen.agReshape,
3447 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003448 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 "error_if_validators": (
3450 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3451 TosaErrorValidator.evWrongInputType,
3452 TosaErrorValidator.evWrongOutputType,
3453 TosaErrorValidator.evWrongInputList,
3454 TosaErrorValidator.evWrongOutputList,
3455 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003456 },
3457 "reverse": {
3458 "op": Op.REVERSE,
3459 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003460 "build_fcn": (
3461 build_reverse,
3462 TosaTensorGen.tgBasic,
3463 TosaTensorValuesGen.tvgDefault,
3464 TosaArgGen.agAxis,
3465 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003466 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003467 "error_if_validators": (
3468 TosaErrorValidator.evAxisSmallerZero,
3469 TosaErrorValidator.evAxisLargerRank,
3470 TosaErrorValidator.evWrongInputType,
3471 TosaErrorValidator.evWrongOutputType,
3472 TosaErrorValidator.evWrongInputList,
3473 TosaErrorValidator.evWrongOutputList,
3474 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003475 },
3476 "slice": {
3477 "op": Op.SLICE,
3478 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003479 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003480 "build_fcn": (
3481 build_slice,
3482 TosaTensorGen.tgBasic,
3483 TosaTensorValuesGen.tvgDefault,
3484 TosaArgGen.agSlice,
3485 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003486 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003487 "error_if_validators": (
3488 TosaErrorValidator.evStartSmallerZero,
3489 TosaErrorValidator.evSizeSmallerEqualZero,
3490 TosaErrorValidator.evStartSizeOutsideBounds,
3491 TosaErrorValidator.evSizeOutputShapeMismatch,
3492 TosaErrorValidator.evInputSizeStartLengthMismatch,
3493 TosaErrorValidator.evWrongRank,
3494 TosaErrorValidator.evWrongInputType,
3495 TosaErrorValidator.evWrongOutputType,
3496 TosaErrorValidator.evWrongInputList,
3497 TosaErrorValidator.evWrongOutputList,
3498 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003499 },
3500 "tile": {
3501 "op": Op.TILE,
3502 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003503 "build_fcn": (
3504 build_tile,
3505 TosaTensorGen.tgBasic,
3506 TosaTensorValuesGen.tvgDefault,
3507 TosaArgGen.agTile,
3508 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003509 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003510 "error_if_validators": (
3511 TosaErrorValidator.evWrongInputType,
3512 TosaErrorValidator.evWrongOutputType,
3513 TosaErrorValidator.evWrongInputList,
3514 TosaErrorValidator.evWrongOutputList,
3515 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003516 },
3517 "transpose": {
3518 "op": Op.TRANSPOSE,
3519 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003520 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003521 "build_fcn": (
3522 build_transpose,
3523 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003524 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003525 TosaArgGen.agTranspose,
3526 ),
3527 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003528 "error_if_validators": (
3529 TosaErrorValidator.evIndexOutsideBounds,
3530 TosaErrorValidator.evIndexUsedTwice,
3531 TosaErrorValidator.evWrongInputType,
3532 TosaErrorValidator.evWrongOutputType,
3533 TosaErrorValidator.evWrongInputList,
3534 TosaErrorValidator.evWrongOutputList,
3535 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003536 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 # Data nodes
3538 "const": {
3539 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003540 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541 "build_fcn": (
3542 build_const,
3543 TosaTensorGen.tgBasic,
3544 TosaTensorValuesGen.tvgDefault,
3545 None,
3546 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "types": TYPE_FIB,
3548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "identity": {
3550 "op": Op.IDENTITY,
3551 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
3553 build_unary,
3554 TosaTensorGen.tgBasic,
3555 TosaTensorValuesGen.tvgDefault,
3556 None,
3557 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003558 "types": TYPE_FIB,
3559 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003560 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003561 "gather": {
3562 "op": Op.GATHER,
3563 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3564 "operands": (1, 0),
3565 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003566 "build_fcn": (
3567 build_gather,
3568 TosaTensorGen.tgBasic,
3569 TosaTensorValuesGen.tvgDefault,
3570 None,
3571 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003572 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003573 "error_if_validators": (
3574 TosaErrorValidator.evWrongInputType,
3575 TosaErrorValidator.evWrongOutputType,
3576 TosaErrorValidator.evWrongInputList,
3577 TosaErrorValidator.evWrongOutputList,
3578 TosaErrorValidator.evWrongRank,
3579 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003580 },
3581 "scatter": {
3582 "op": Op.SCATTER,
3583 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003584 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003585 "operands": (2, 0),
3586 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003587 "build_fcn": (
3588 build_scatter,
3589 TosaTensorGen.tgScatter,
3590 TosaTensorValuesGen.tvgDefault,
3591 None,
3592 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003593 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003594 "error_if_validators": (
3595 TosaErrorValidator.evWrongInputType,
3596 TosaErrorValidator.evWrongOutputType,
3597 TosaErrorValidator.evWrongInputList,
3598 TosaErrorValidator.evWrongOutputList,
3599 TosaErrorValidator.evWrongRank,
3600 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003601 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003602 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003603 "resize": {
3604 "op": Op.RESIZE,
3605 "operands": (1, 0),
3606 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003607 "build_fcn": (
3608 build_resize,
3609 TosaTensorGen.tgNHWC,
3610 TosaTensorValuesGen.tvgDefault,
3611 TosaArgGen.agResize,
3612 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003613 "types": [DType.INT8, DType.INT16, DType.FLOAT],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003614 "invalid_test_validators": (
3615 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
3616 TosaInvalidValidator.ivBadStride,
3617 ),
3618 "error_if_validators": (
3619 TosaErrorValidator.evMaxDimExceeded,
3620 TosaErrorValidator.evStrideSmallerEqualZero,
3621 TosaErrorValidator.evStrideLargerDimension,
3622 TosaErrorValidator.evStrideLargerEqualMax,
3623 TosaErrorValidator.evOffsetSmallerEqualMin,
3624 TosaErrorValidator.evOffsetLargerEqualMax,
3625 TosaErrorValidator.evShiftNotZero,
3626 TosaErrorValidator.evShiftSmallerOne,
3627 TosaErrorValidator.evShiftLargerEleven,
3628 TosaErrorValidator.evWrongInputType,
3629 TosaErrorValidator.evWrongOutputType,
3630 TosaErrorValidator.evWrongRank,
3631 TosaErrorValidator.evWrongInputList,
3632 TosaErrorValidator.evWrongOutputList,
3633 TosaErrorValidator.evBatchMismatch,
3634 TosaErrorValidator.evChannelMismatch,
3635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003636 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003637 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 "cast": {
3639 "op": Op.CAST,
3640 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 "build_fcn": (
3642 build_cast,
3643 TosaTensorGen.tgBasic,
3644 TosaTensorValuesGen.tvgDefault,
3645 TosaArgGen.agCast,
3646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003648 "error_if_validators": (
3649 TosaErrorValidator.evWrongInputType,
3650 TosaErrorValidator.evWrongOutputType,
3651 TosaErrorValidator.evWrongInputList,
3652 TosaErrorValidator.evWrongOutputList,
3653 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003654 },
3655 "rescale": {
3656 "op": Op.RESCALE,
3657 "operands": (1, 0),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003658 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003659 "build_fcn": (
3660 build_rescale,
3661 TosaTensorGen.tgBasic,
3662 TosaTensorValuesGen.tvgDefault,
3663 TosaArgGen.agRescale,
3664 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003665 "types": [
3666 DType.UINT8,
3667 DType.INT8,
3668 DType.INT16,
3669 DType.INT32,
3670 DType.INT48,
3671 DType.UINT16,
3672 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003673 "error_if_validators": (
3674 TosaErrorValidator.evInputZeroPointNotZero,
3675 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003676 TosaErrorValidator.evU16InputZeroPointNotValid,
3677 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003678 TosaErrorValidator.evScaleTrue,
3679 TosaErrorValidator.evScaleNotTrue,
3680 TosaErrorValidator.evWrongInputType,
3681 TosaErrorValidator.evWrongOutputType,
3682 TosaErrorValidator.evWrongRank,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003686 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003687 # Custom
3688 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003690 # Two varients of cond_if, one that generates one of two constant tensors (no
3691 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3692 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003693 "cond_if_const": {
3694 "op": Op.COND_IF,
3695 "operands": (0, 2),
3696 "build_fcn": (
3697 build_cond_if_const,
3698 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003699 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003700 TosaArgGen.agCondIf,
3701 ),
3702 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003703 "error_if_validators": (
3704 TosaErrorValidator.evOutputListThenGraphMismatch,
3705 TosaErrorValidator.evOutputListElseGraphMismatch,
3706 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003707 },
3708 "cond_if_binary": {
3709 "op": Op.COND_IF,
3710 "operands": (2, 0),
3711 "build_fcn": (
3712 build_cond_if_binary,
3713 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003715 TosaArgGen.agCondIf,
3716 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003717 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003718 "error_if_validators": (
3719 TosaErrorValidator.evInputListThenGraphMismatch,
3720 TosaErrorValidator.evInputListElseGraphMismatch,
3721 TosaErrorValidator.evOutputListThenGraphMismatch,
3722 TosaErrorValidator.evOutputListElseGraphMismatch,
3723 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003724 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003725 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003726 "while_loop": {
3727 "op": Op.WHILE_LOOP,
3728 "operands": (0, 1),
3729 "build_fcn": (
3730 build_while_loop,
3731 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003733 TosaArgGen.agWhileLoop,
3734 ),
3735 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 "error_if_validators": (
3737 TosaErrorValidator.evInputListOutputListMismatch,
3738 TosaErrorValidator.evInputListCondGraphMismatch,
3739 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3740 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3741 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3742 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003743 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003744 }
3745
Kevin Cheng550ccc52021-03-03 11:21:43 -08003746
Eric Kunzee5e26762020-10-13 16:11:07 -07003747class OutputShaper:
3748 # Methods in this class compute the expected output shape and datatype
3749 # for common classes of operations
def __init__(self):
    """Create an OutputShaper; no per-instance state is kept."""
3752
3753 # These methods return arguments that can be used for
3754 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise binary op that may broadcast.

    Args:
        ser: serializer; its addOutput(shape, dtype) creates the output tensor.
        rng: random generator providing choice(), used for ERROR_IF tests.
        a, b: input tensors (objects exposing .shape and .dtype).
        error_name: ErrorIf value of the negative test, or None.

    For a normal test (error_name is None) each output dimension takes b's
    extent wherever a's extent is 1, otherwise a's extent. For ERROR_IF
    tests a's shape is copied unchanged. With ErrorIf.WrongOutputType the
    output dtype is drawn at random from a fixed list, excluding a's dtype.

    Returns:
        The value returned by ser.addOutput.
    """
    # A RankMismatch test deliberately feeds operands of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    may_broadcast = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and may_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003782
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a binary op whose operands must match exactly.

    Both operands must share rank, dtype and every dimension (checked with
    asserts). The output gets a fresh copy of a's shape and a's dtype.

    Returns:
        The value returned by ser.addOutput.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003794
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op.

    The output keeps a's shape. Its dtype is a's dtype, except for
    ErrorIf.WrongOutputType tests, where a different dtype is drawn at random
    from a fixed list.

    Returns:
        The value returned by ser.addOutput.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003811
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    The shape follows cond. For a normal test (error_name is None), any
    dimension where cond's extent is 1 becomes the largest extent among cond,
    a and b. With ErrorIf.WrongOutputType the dtype is drawn at random from a
    fixed list excluding a's dtype; otherwise a's dtype is used.

    Returns:
        The value returned by ser.addOutput.
    """
    # A RankMismatch test deliberately feeds operands of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    may_broadcast = error_name is None
    out_shape = []
    for idx, c_dim in enumerate(cond.shape):
        if may_broadcast and c_dim == 1:
            c_dim = max(c_dim, a.shape[idx], b.shape[idx])
        out_shape.append(c_dim)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003839
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor for an elementwise comparison.

    Shape: a's extent per dimension, replaced by b's extent where a's extent
    is 1 and b has that dimension. This broadcast is applied for ERROR_IF
    tests too. With ErrorIf.WrongOutputType the dtype is drawn from a list
    that does not contain BOOL, so it is always wrong.

    Returns:
        The value returned by ser.addOutput.
    """
    # A RankMismatch test deliberately feeds operands of different rank.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003867
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction along ``axis``.

    Normally the reduced axis becomes extent 1. For invalid-axis ERROR_IF
    tests the input shape is kept. For ErrorIf.ShapeOfAxisNotOne the axis is
    left as-is, or forced to a random value in [2, 10) if it was already 1.
    With ErrorIf.WrongOutputType the dtype is drawn from a fixed list
    excluding a's dtype.

    Returns:
        The value returned by ser.addOutput.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003894
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor of ARGMAX along ``axis``.

    The output shape is a's shape with ``axis`` removed, unless the test uses
    an invalid axis. ERROR_IF variants then perturb the result:
      * ArgmaxOutputRankMismatch: randomly drop the leading dim (only if more
        than one remains) or append a trailing 1.
      * ArgmaxOutputShapeMismatch: add a random 1..9 to every dimension.
      * WrongOutputType: any listed dtype other than INT32.

    Returns:
        The value returned by ser.addOutput.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = DType.INT32
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([DType.INT32])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003926
@staticmethod
def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Add the output tensor of a CONV2D.

    Layouts: ifm NHWC, filter OHWI, output NHWC. Each spatial extent is
        (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1
    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.

    Raises:
        Exception: for an unsupported input dtype outside a WrongInputType test.
    """

    def out_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow the picked dimension(s) by whole strides so the only ERROR_IF
        # hit is the shape mismatch, not a non-integer output size.
        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    accum_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accum_dtype:
        out_dtype = accum_dtype[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003978
@staticmethod
def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
    """Add the output tensor of a CONV3D.

    Layouts: ifm NDHWC, filter ODHWI, output NDHWC. For spatial axis i
    (0=D, 1=H, 2=W) the extent is
        (ifm[i+1] - 1 + padding[2i] + padding[2i+1]
         - (filter[i+1] - 1) * dilations[i]) // strides[i] + 1
    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.

    Raises:
        Exception: for an unsupported input dtype outside a WrongInputType test.
    """

    def out_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    dims = [
        out_extent(
            ifm.shape[i + 1],
            padding[2 * i],
            padding[2 * i + 1],
            filter.shape[i + 1],
            dilations[i],
            strides[i],
        )
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # change 1/2/3 grows only D/H/W respectively; 4 grows all three.
        # Growth is in whole strides to avoid the non-integer ERROR_IF case.
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        for i in range(3):
            if change in (i + 1, 4):
                dims[i] += rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], *dims, filter.shape[0]]

    accum_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accum_dtype:
        out_dtype = accum_dtype[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
4040
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a DEPTHWISE_CONV2D.

    Layouts: ifm NHWC, filter HWCM, output NHW(C*M). Spatial extents use
        (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1
    with kernel sizes filter.shape[0] (H) and filter.shape[1] (W).
    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.

    Raises:
        Exception: for an unsupported input dtype outside a WrongInputType test.
    """

    def out_extent(in_size, pad_before, pad_after, kernel, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by whole strides so only the shape-mismatch ERROR_IF is hit.
        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    channels_out = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], h, w, channels_out]

    accum_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }
    if ifm.dtype in accum_dtype:
        out_dtype = accum_dtype[ifm.dtype]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004093
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2D pooling op (NHWC in and out).

    Spatial extents are (in + pad_before + pad_after - kernel) // stride + 1.
    A non-positive stride or negative padding (which only valid in ERROR_IF
    tests) yields a 1x1 spatial output instead. The channel count and, unless
    ErrorIf.WrongOutputType is requested, the dtype follow the input.

    Returns:
        The value returned by ser.addOutput.
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # The test is invalid anyway, so any small shape will do.
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Grow by whole strides so only the shape-mismatch ERROR_IF is hit.
        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in (1, 3):
            h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([ifm.dtype])))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004129
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, error_name=None):
    """Add the output tensor of a FULLY_CONNECTED op.

    Layouts: input (N, IC), filter (OC, IC), output (N, OC).
    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.
    ErrorIf.WrongOutputType picks at random from a per-input list that
    excludes the correct output dtype. ErrorIf.WrongInputType with an
    unsupported input dtype falls back to INT32 as a plausible output.

    Returns:
        The value returned by ser.addOutput.

    Raises:
        Exception: if the input dtype is not INT8/INT16/FLOAT and the test is
            not a WrongInputType test. This includes WrongOutputType, which
            previously failed with an UnboundLocalError.
    """
    output_shape = [input.shape[0], filter.shape[0]]

    if error_name == ErrorIf.WrongOutputType:
        if input.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif input.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif input.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch incorrect_types would be unbound below.
            raise Exception(
                "Unsupported input dtype for WrongOutputType test: {}".format(
                    input.dtype
                )
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif input.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif input.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif input.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004177
@staticmethod
def matmulOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a batched MATMUL.

    Layouts: a (N, H, C), b (N, C, W), output (N, H, W).
    Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT.
    ErrorIf.WrongOutputType picks at random from a per-input list that
    excludes the correct output dtype. ErrorIf.WrongInputType with an
    unsupported input dtype falls back to INT32 as a plausible output.

    Returns:
        The value returned by ser.addOutput.

    Raises:
        Exception: if a's dtype is not INT8/INT16/FLOAT and the test is not a
            WrongInputType test. This includes WrongOutputType, which
            previously failed with an UnboundLocalError.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif a.dtype == DType.FLOAT:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Without this branch incorrect_types would be unbound below.
            raise Exception(
                "Unsupported input dtype for matmul WrongOutputType test: {}".format(
                    a.dtype
                )
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004225
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor of CONCAT along ``axis``.

    The output starts as a copy of the first input's shape. When possible,
    the extents of the remaining inputs along ``axis`` are added to it. Rank
    mismatch and invalid-axis ERROR_IF tests skip that sum.
    ErrorIf.ConcatShapeSumMismatch adds a random 5..9 to the axis extent.
    ErrorIf.WrongOutputType draws a dtype different from the first input's.

    Returns:
        The value returned by ser.addOutput.
    """
    first = a[0]
    out_shape = first.shape.copy()

    can_sum = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for other in a[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        }
        out_dtype = rng.choice(list(all_dtypes - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004259
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of PAD.

    Each output extent is padding[i][0] + padding[i][1] + a.shape[i]. For an
    ErrorIf.PadSmallerZero test, any extent below 1 is clamped to 1.
    ErrorIf.WrongOutputType draws a dtype different from a's.

    Returns:
        The value returned by ser.addOutput.
    """
    out_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    # Negative padding in the ERROR_IF test can produce non-positive extents.
    if error_name == ErrorIf.PadSmallerZero and min(out_shape) < 1:
        out_shape = [dim if dim >= 1 else 1 for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004286
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of RESHAPE to ``shape``.

    Any -1 entry in ``shape`` is replaced by
    (element count of a) // (product of the other entries).
    ErrorIf.TensorSizeInputOutputMismatch adds a random 1..9 to every
    dimension. ErrorIf.WrongOutputType draws a dtype different from a's.

    Returns:
        The value returned by ser.addOutput.
    """
    in_elems = 1
    for dim in a.shape:
        in_elems *= dim

    known_elems = 1
    for dim in shape:
        if dim != -1:
            known_elems *= dim

    out_shape = [in_elems // known_elems if dim == -1 else dim for dim in shape]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004324
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add the output tensor of SLICE; its shape is a copy of ``size``.

    ErrorIf.WrongOutputType draws a dtype different from a's.
    ErrorIf.SizeOutputShapeMismatch perturbs every dimension. Extents <= 2
    gain 1 or 2, and larger extents change by one of -2, -1, +1 or +2.
    ``start`` is not used.

    Returns:
        The value returned by ser.addOutput.
    """
    # The dtype draw is made before any shape draws, as in the original order.
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - set([a.dtype])))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        out_shape = [
            dim + rng.choice([1, 2] if dim <= 2 else [-2, -1, 1, 2])
            for dim in out_shape
        ]

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004354
4355 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for a TILE test.

    Output dimension ``i`` is ``a.shape[i] * multiples[i]``. The dtype is
    ``a.dtype`` unless ``error_name`` is ErrorIf.WrongOutputType, in which
    case a different dtype is drawn with ``rng.choice``.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for axis, extent in enumerate(a.shape):
        tiled_shape[axis] = extent * multiples[axis]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(tiled_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004378
4379 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE test.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For
    ErrorIf.IndexOutsideBounds every output dimension is ``a.shape[0]``
    instead, and ``perms`` is not indexed. The dtype is ``a.dtype`` unless
    ``error_name`` is ErrorIf.WrongOutputType.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    permuted_shape = a.shape.copy()

    assert len(perms) == len(permuted_shape)

    ignore_perms = error_name == ErrorIf.IndexOutsideBounds
    for axis in range(len(permuted_shape)):
        source_axis = 0 if ignore_perms else perms[axis]
        permuted_shape[axis] = a.shape[source_axis]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(permuted_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004406
4407 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER test.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The dtype is
    ``values.dtype`` unless ``error_name`` is ErrorIf.WrongOutputType. For
    ErrorIf.WrongRank the rank-3 check on ``values`` is skipped.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    out_n = values.shape[0]
    out_w = indices.shape[1]
    # NOTE(review): under WrongRank a values tensor of rank < 3 would raise
    # IndexError here; presumably such ranks are never generated -- confirm.
    out_c = values.shape[2]
    gathered_shape = [out_n, out_w, out_c]

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput(gathered_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004430
4431 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER test.

    The output has the shape of ``values_in``. Its dtype is also that of
    ``values_in``, unless ``error_name`` is ErrorIf.WrongOutputType.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` is called.
        rng: random generator; only ``choice`` is used.
        values_in: tensor whose ``shape`` and ``dtype`` the output copies.
        indices: rank-2 tensor; dims 0 and 1 must match N and W.
        input: rank-3 tensor; dims 1 and 2 must match W and C.
        error_name: ErrorIf value selecting an invalid variant, or None.
            For ErrorIf.WrongRank the rank-3 check on ``values_in`` is
            skipped.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    # Copy the shape, as the reshape/slice/tile/transpose shapers do, so
    # the output tensor does not alias values_in's shape object.
    output_shape = values_in.shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = values_in.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004457
4458 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE test.

    The output keeps ``input.shape``. An INT16 input maps to an INT32
    output; any other input dtype maps to INT8. For ErrorIf.WrongInputType
    the INT8/INT16 input check is skipped. For ErrorIf.WrongOutputType one
    of the other listed dtypes is drawn with ``rng.choice``.

    Returns:
        Whatever ``ser.addOutput`` returns for the shape and chosen dtype.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    result_dtype = DType.INT8
    if input.dtype == DType.INT16:
        result_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        # Every listed dtype except the correct one, kept in list order.
        alternatives = [
            dt
            for dt in (
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            )
            if dt != result_dtype
        ]
        result_dtype = rng.choice(alternatives)

    return ser.addOutput(input.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004475
4476 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE test.

    Only ``input.shape``, ``output_dims``, ``output_dtype`` and
    ``error_name`` are read. The other parameters are accepted for interface
    compatibility.

    The normal output shape is
    ``[input.shape[0], output_dims[0], output_dims[1], input.shape[3]]``.
    ErrorIf.BatchMismatch adds ``rng.integers(1, 10)`` to entry 0, and
    ErrorIf.ChannelMismatch adds it to entry 3.

    Returns:
        Whatever ``serializer.addOutput`` returns for the shape and
        ``output_dtype``.
    """
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): repeats output_dims[0] and reuses input.shape[0] as
        # the last entry. Only index 0 of each is read, presumably so inputs
        # of the wrong rank cannot raise IndexError here -- confirm intent.
        wrong_rank_dims = [
            input.shape[0],
            output_dims[0],
            output_dims[0],
            input.shape[0],
        ]
        return serializer.addOutput(wrong_rank_dims, output_dtype)

    batch = input.shape[0]
    if error_name == ErrorIf.BatchMismatch:
        batch = batch + rng.integers(1, 10)
    dim1 = output_dims[0]
    dim2 = output_dims[1]
    channels = input.shape[3]
    if error_name == ErrorIf.ChannelMismatch:
        channels = channels + rng.integers(1, 10)

    return serializer.addOutput([batch, dim1, dim2, channels], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004523
4524 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for interface compatibility but
    are not read.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    converted_shape = val.shape
    return ser.addOutput(converted_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004527
4528 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D test.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` is called.
        rng: random generator; only ``choice`` is used.
        ifm: input tensor; only its ``dtype`` is read.
        output_shape: requested output shape. For
            ErrorIf.ConvOutputShapeMismatch, entries 1 and/or 2 are
            increased IN PLACE, so the caller's object is modified as well.
        error_name: ErrorIf value selecting an invalid variant, or None.

    Returns:
        Whatever ``ser.addOutput`` returns for the shape and chosen dtype.

    Raises:
        Exception: if ``ifm.dtype`` is not INT8, INT16 or FLOAT and the
            error variant is not ErrorIf.WrongInputType.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        bump_values = [1, 2, 3]
        # selector 1 -> bump entry 1, 2 -> bump entry 2, 3 -> bump both.
        selector = rng.choice(bump_values)
        for axis, triggers in ((1, (1, 3)), (2, (2, 3))):
            if selector in triggers:
                output_shape[axis] = output_shape[axis] + rng.choice(bump_values)

    in_dtype = ifm.dtype
    if in_dtype == DType.INT8:
        out_dtype = DType.INT32
    elif in_dtype == DType.INT16:
        out_dtype = DType.INT48
    elif in_dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    elif error_name == ErrorIf.WrongInputType:
        # Any plausible output dtype will do when the input type is invalid.
        out_dtype = DType.INT32
    else:
        raise Exception(f"Unsupported input dtype: {ifm.dtype}")

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[out_dtype])))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004554 return ser.addOutput(output_shape, out_dtype)