blob: 9ff6ec5c2976e31e271d3497a525346af39a21ae [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010016from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010017from generator.tosa_utils import usableDTypes
Les Bell0e027d42021-11-09 14:42:14 +000018from tosa.DType import DType
19from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010020
21
class TosaTestGen:
    """Generates TOSA operator tests and serializes them to disk.

    Holds the numpy random Generator used for tensor data and the
    TosaSerializer for the current test. Each ``build_*`` method adds one
    operator (after running any error_if validation) to that serializer.
    """

    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026 def __init__(self, args):
27 self.args = args
28 self.basePath = args.output_dir
29 self.random_seed = args.random_seed
30 self.ser = None
31 self.rng = np.random.default_rng(self.random_seed)
32 self.createDynamicOpLists()
33 self.initOpListDefaults()
34 self.quantGen = TosaQuantGen()
35 # Force makeShape to do a specific starting shape
36 self.targetted_shape = None
37
38 def createSerializer(self, opName, testPath):
39 self.testPath = os.path.join(opName, testPath)
40
41 fullPath = os.path.join(self.basePath, self.testPath)
42 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010043 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070044
45 def getSerializer(self):
46 return self.ser
47
48 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080049 with open(
50 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
51 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070052 fd.write(self.ser.serialize())
53
Kevin Cheng550ccc52021-03-03 11:21:43 -080054 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
55 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070056
Matthew Haddon74567092021-07-16 15:38:20 +010057 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000058 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010059 seed = self.random_seed + 1
60 self.rng = np.random.default_rng(seed)
61
Eric Kunzee5e26762020-10-13 16:11:07 -070062 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070063 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070064 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070065 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070066 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070067 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070068 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010069 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
70 elif dtype == DType.UINT8:
71 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070072 elif dtype == DType.INT16:
73 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010074 elif dtype == DType.UINT16:
75 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070076 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080077 return np.int32(
78 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
79 )
Eric Kunzee5e26762020-10-13 16:11:07 -070080 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080081 return np.int64(
82 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
83 )
James Ward8b390432022-08-12 20:48:56 +010084 elif dtype == DType.FP16:
85 return np.float16(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070086 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +010087 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070088 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -080089 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -070090
Kevin Cheng989cb052021-04-28 16:29:44 -070091 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -070092 placeholders = []
93
Kevin Cheng989cb052021-04-28 16:29:44 -070094 assert len(shape_list) == len(dtype_list)
95
96 for idx, shape in enumerate(shape_list):
97 arr = self.getRandTensor(shape, dtype_list[idx])
98 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -070099
100 return placeholders
101
Kevin Cheng989cb052021-04-28 16:29:44 -0700102 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700103 consts = []
104
Kevin Cheng989cb052021-04-28 16:29:44 -0700105 assert len(shape_list) == len(dtype_list)
106
107 for idx, shape in enumerate(shape_list):
108 arr = self.getRandTensor(shape, dtype_list[idx])
109 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700110
111 return consts
112
113 def makeShape(self, rank):
114 if self.targetted_shape:
115 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800116 return np.int32(
117 self.rng.integers(
118 low=self.args.tensor_shape_range[0],
119 high=self.args.tensor_shape_range[1],
120 size=rank,
121 )
122 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700123
124 def setTargetShape(self, shape):
125 self.targetted_shape = shape
126
127 def randInt(self, low=0, high=256):
128 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
129
130 def getRandNumberDType(self, dtype):
131 if dtype == DType.FLOAT:
132 return self.rng.random()
James Ward8b390432022-08-12 20:48:56 +0100133 elif dtype == DType.FP16:
134 rand_f32 = self.rng.random()
135 return np.float16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700136 elif dtype == DType.BOOL:
137 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700138 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700139 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700140 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100142 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 elif dtype == DType.INT16:
144 low, high = (-32768, 32768)
145 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800146 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700147 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800148 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700149 # Special size
150 return np.int64(self.rng.integers(low, high, size=1))[0]
151 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800152 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700153
154 return np.int32(self.rng.integers(low, high, size=1))[0]
155
156 def shapeStr(self, shape):
157
158 sStr = []
159 # Convert to strings
160 for i in shape:
161 sStr.append(str(i))
162
Kevin Cheng550ccc52021-03-03 11:21:43 -0800163 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700164
165 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -0700166 if isinstance(t, list):
167 assert len(t) >= 2
168 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700169 else:
Kevin Cheng989cb052021-04-28 16:29:44 -0700170 if t == DType.BOOL:
171 return "b"
172 elif t == DType.INT4:
173 return "i4"
174 elif t == DType.INT8:
175 return "i8"
176 elif t == DType.UINT8:
177 return "u8"
178 elif t == DType.INT16:
179 return "i16"
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100180 elif t == DType.UINT16:
181 return "u16"
Kevin Cheng989cb052021-04-28 16:29:44 -0700182 elif t == DType.INT32:
183 return "i32"
184 elif t == DType.INT48:
185 return "i48"
James Ward8b390432022-08-12 20:48:56 +0100186 elif t == DType.FP16:
187 return "f16"
Kevin Cheng989cb052021-04-28 16:29:44 -0700188 elif t == DType.FLOAT:
189 return "float"
190 else:
191 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -0700192
193 def typeWidth(self, t):
James Ward8b390432022-08-12 20:48:56 +0100194 """Get the datatype width for data types"""
Kevin Cheng3a478572021-01-22 17:21:02 -0800195 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -0700196 return 4
197 elif t == DType.INT8:
198 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -0800199 elif t == DType.UINT8:
200 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 elif t == DType.INT16:
202 return 16
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +0100203 elif t == DType.UINT16:
204 return 16
Eric Kunzee5e26762020-10-13 16:11:07 -0700205 elif t == DType.INT32:
206 return 32
207 elif t == DType.INT48:
208 return 48
James Ward8b390432022-08-12 20:48:56 +0100209 elif t == DType.FP16:
210 return 16
Matthew Haddonc2025212021-10-08 21:21:05 +0100211 elif t == DType.FLOAT:
212 return 32
213 elif t == DType.BOOL:
214 return 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 else:
Les Bell729b0352021-11-24 10:28:21 +0000216 raise Exception(f"Unknown dtype, cannot determine width: {t}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700217
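    # Operator build functions.
    # Each build_* method below adds one operator to the serializer and returns
    # its output tensor (or None when error_if validation rejects the test).
    # Their arguments come from the argument generators (TosaArgGen, imported
    # from generator.tosa_arg_gen), which return a list of tuples
    # (stringDescriptor, [build_fcn_arg_list]). The string descriptor is used
    # to generate the test name, and the build_fcn_arg_list is expanded and
    # passed to the operator test build function.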
223
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100224 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
225 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
226
Matthew Haddon848efb42021-09-09 12:30:53 +0100227 # build_placeholder returns an int, ABS/other ops does not
228 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000229 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100230 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000231 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000232 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100233 return result_tens
234
235 # Ensure new output type has correct qinfo
236 if error_name == ErrorIf.WrongOutputType:
237 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000238 qinfo = [
239 TosaQuantGen.getZeroPoint(self, a.dtype),
240 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
241 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100242
243 # Invalidate Input/Output list for error if checks.
244 input_list = [a.name]
245 output_list = [result_tens.name]
246 pCount, cCount = op["operands"]
247 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000248 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
249 self, error_name, input_list, output_list
250 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100251
Les Bell729b0352021-11-24 10:28:21 +0000252 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100253 self.ser,
254 validator_fcns,
255 error_name,
256 op=op,
257 input_dtype=a.dtype,
258 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000259 qinfo=qinfo,
260 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100261 input_list=input_list,
262 output_list=output_list,
263 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000264 ):
265 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100266
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000267 attr = None
268 if op["op"] == Op.NEGATE:
269 attr = ts.TosaSerializerAttribute()
270 attr.NegateAttribute(qinfo[0], qinfo[1])
271
272 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700273 return result_tens
274
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100275 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000276 result_tens = OutputShaper.binaryBroadcastOp(
277 self.ser, self.rng, a, b, error_name
278 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100279
280 # Invalidate Input/Output list for error if checks.
281 input_list = [a.name, b.name]
282 output_list = [result_tens.name]
283 pCount, cCount = op["operands"]
284 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000285 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
286 self, error_name, input_list, output_list
287 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100288
Les Bell729b0352021-11-24 10:28:21 +0000289 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100290 self.ser,
291 validator_fcns,
292 error_name,
293 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000294 input1=a,
295 input2=b,
296 input_dtype=a.dtype,
297 output_dtype=result_tens.dtype,
298 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100299 input_list=input_list,
300 output_list=output_list,
301 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000302 ):
303 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100304
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000305 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700306 return result_tens
307
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100308 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700309 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000310 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700311 return result_tens
312
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000313 def build_arithmetic_right_shift(
314 self, op, a, b, round, validator_fcns=None, error_name=None
315 ):
316 result_tens = OutputShaper.binaryBroadcastOp(
317 self.ser, self.rng, a, b, error_name
318 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100319
320 # Invalidate Input/Output list for error if checks.
321 input_list = [a.name, b.name]
322 output_list = [result_tens.name]
323 pCount, cCount = op["operands"]
324 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000325 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
326 self, error_name, input_list, output_list
327 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100328
Les Bell729b0352021-11-24 10:28:21 +0000329 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100330 self.ser,
331 validator_fcns,
332 error_name,
333 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000334 input1=a,
335 input2=b,
336 input_dtype=a.dtype,
337 output_dtype=result_tens.dtype,
338 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100339 input_list=input_list,
340 output_list=output_list,
341 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000342 ):
343 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800344
345 attr = ts.TosaSerializerAttribute()
346 attr.ArithmeticRightShiftAttribute(round)
347
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000348 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800349 return result_tens
350
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100351 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000352 result_tens = OutputShaper.binaryBroadcastOp(
353 self.ser, self.rng, a, b, error_name
354 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700355
356 # Special for multiply:
357 # Force the result to INT32 for INT types
James Ward8b390432022-08-12 20:48:56 +0100358 if a.dtype not in (DType.FP16, DType.FLOAT):
Eric Kunzee5e26762020-10-13 16:11:07 -0700359 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100360 if error_name == ErrorIf.WrongOutputType:
361 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
362 outputDType = self.rng.choice(all_dtypes)
363 result_tens.setDtype(outputDType)
364
365 # Invalidate Input/Output list for error if checks.
366 input_list = [a.name, b.name]
367 output_list = [result_tens.name]
368 pCount, cCount = op["operands"]
369 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000370 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
371 self, error_name, input_list, output_list
372 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100373
Les Bell729b0352021-11-24 10:28:21 +0000374 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100375 self.ser,
376 validator_fcns,
377 error_name,
378 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000379 input1=a,
380 input2=b,
381 input_dtype=a.dtype,
382 output_dtype=result_tens.dtype,
383 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100384 input_list=input_list,
385 output_list=output_list,
386 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000387 ):
388 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700389
Kevin Chengaee1fac2020-11-11 13:54:06 -0800390 attr = ts.TosaSerializerAttribute()
391 attr.MulAttribute(shift)
392
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700394 return result_tens
395
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100396 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
397 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700398
Kevin Chengfe392ce2021-10-18 21:51:55 +0000399 attr = ts.TosaSerializerAttribute()
400 attr.TableAttribute(table)
401
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100402 # Invalidate Input/Output list for error if checks.
403 input_list = [a.name]
404 output_list = [result_tens.name]
405 pCount, cCount = op["operands"]
406 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000407 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
408 self, error_name, input_list, output_list
409 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410
Les Bell729b0352021-11-24 10:28:21 +0000411 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100412 self.ser,
413 validator_fcns,
414 error_name,
415 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000416 input_shape=a.shape,
417 input_dtype=a.dtype,
418 output_dtype=result_tens.dtype,
419 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100420 input_list=input_list,
421 output_list=output_list,
422 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000423 ):
424 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100425
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000426 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700427
428 return result_tens
429
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100430 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
431 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
432
433 # Invalidate Input/Output list for error if checks.
434 input_list = [cond.name, a.name, b.name]
435 output_list = [result_tens.name]
436 pCount, cCount = op["operands"]
437 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000438 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
439 self, error_name, input_list, output_list
440 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100441
Les Bell729b0352021-11-24 10:28:21 +0000442 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100443 self.ser,
444 validator_fcns,
445 error_name,
446 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000447 input1=cond,
448 input2=a,
449 input3=b,
450 input_shape=a.shape,
451 input_dtype=a.dtype,
452 output_dtype=result_tens.dtype,
453 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100454 input_list=input_list,
455 output_list=output_list,
456 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000457 ):
458 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100459
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000460 self.ser.addOperator(
461 op["op"],
462 input_list,
463 output_list,
464 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700465 return result_tens
466
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100467 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000468 result_tens = OutputShaper.binaryComparisonOp(
469 self.ser, self.rng, a, b, error_name
470 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100471
472 # Invalidate Input/Output list for error if checks.
473 input_list = [a.name, b.name]
474 output_list = [result_tens.name]
475 pCount, cCount = op["operands"]
476 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000477 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
478 self, error_name, input_list, output_list
479 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100480
Les Bell729b0352021-11-24 10:28:21 +0000481 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100482 self.ser,
483 validator_fcns,
484 error_name,
485 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000486 input1=a,
487 input2=b,
488 input_shape=a.shape,
489 input_dtype=a.dtype,
490 output_shape=result_tens.shape,
491 output_dtype=result_tens.dtype,
492 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100493 input_list=input_list,
494 output_list=output_list,
495 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000496 ):
497 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100498
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000499 self.ser.addOperator(
500 op["op"],
501 input_list,
502 output_list,
503 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700504 return result_tens
505
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100506 def build_argmax(self, op, a, axis, validator_fcns, error_name):
507 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
508
509 # Invalidate Input/Output list for error if checks.
510 input_list = [a.name]
511 output_list = [result_tens.name]
512 pCount, cCount = op["operands"]
513 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000514 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
515 self, error_name, input_list, output_list
516 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100517
Les Bell729b0352021-11-24 10:28:21 +0000518 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100519 self.ser,
520 validator_fcns,
521 error_name,
522 op=op,
523 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000524 input_shape=a.shape,
525 input_dtype=a.dtype,
526 output_shape=result_tens.shape,
527 output_dtype=result_tens.dtype,
528 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100529 input_list=input_list,
530 output_list=output_list,
531 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000532 ):
533 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700534
535 attr = ts.TosaSerializerAttribute()
536 attr.AxisAttribute(axis)
537
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000538 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700539 return result_tens
540
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000541 def build_pool2d(
542 self,
543 op,
544 input,
James Ward8b390432022-08-12 20:48:56 +0100545 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000546 stride,
547 pad,
548 kernel,
549 validator_fcns=None,
550 error_name=None,
551 qinfo=None,
552 ):
553 result_tens = OutputShaper.pool2dOp(
554 self.ser, self.rng, input, kernel, stride, pad, error_name
555 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100556
557 # Ensure new output type has correct qinfo
558 if error_name == ErrorIf.WrongInputType:
559 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000560 qinfo = [
561 TosaQuantGen.getZeroPoint(self, input.dtype),
562 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
563 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100564
565 # Invalidate Input/Output list for error if checks.
566 input_list = [input.name]
567 output_list = [result_tens.name]
568 pCount, cCount = op["operands"]
569 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000570 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
571 self, error_name, input_list, output_list
572 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100573
Les Bell729b0352021-11-24 10:28:21 +0000574 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100575 self.ser,
576 validator_fcns,
577 error_name,
578 op=op,
579 input_shape=input.shape,
580 input_dtype=input.dtype,
581 output_shape=result_tens.shape,
582 output_dtype=result_tens.dtype,
583 kernel=kernel,
584 stride=stride,
585 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000586 qinfo=qinfo,
587 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100588 input_list=input_list,
589 output_list=output_list,
590 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000591 ):
592 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700593
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000594 if qinfo is None:
595 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700596
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000597 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100598 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000599
600 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700601 return result_tens
602
James Ward8b390432022-08-12 20:48:56 +0100603 def build_maxpool2d(
604 self,
605 op,
606 input,
607 stride,
608 pad,
609 kernel,
610 validator_fcns=None,
611 error_name=None,
612 qinfo=None,
613 ):
614 # Same as build_pool2d but manually sets accum_dtype value
615 # (maxpool has no accum_dtype)
616 return self.build_pool2d(
617 op,
618 input,
619 DType.UNKNOWN,
620 stride,
621 pad,
622 kernel,
623 validator_fcns,
624 error_name,
625 qinfo,
626 )
627
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000628 def build_conv2d(
629 self,
630 op,
631 ifm,
632 filter,
633 bias,
James Ward8b390432022-08-12 20:48:56 +0100634 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000635 strides,
636 padding,
637 dilations,
638 validator_fcns=None,
639 error_name=None,
640 qinfo=None,
641 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800642 assert len(padding) == 4
643 result_tens = OutputShaper.conv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100644 self.ser,
645 self.rng,
646 ifm,
647 filter,
648 accum_dtype,
649 strides,
650 padding,
651 dilations,
652 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000653 )
654
655 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000656 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
657 DType.INT8,
658 DType.UINT8,
659 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000660 qinfo = [
661 TosaQuantGen.getZeroPoint(self, ifm.dtype),
662 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
663 ]
Les Bell0e027d42021-11-09 14:42:14 +0000664
665 # Invalidate Input/Output list for error_if checks.
666 input_list = [ifm.name, filter.name, bias.name]
667 output_list = [result_tens.name]
668 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000669 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
670 self, error_name, input_list, output_list
671 )
Les Bell0e027d42021-11-09 14:42:14 +0000672
Les Bell729b0352021-11-24 10:28:21 +0000673 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000674 self.ser,
675 validator_fcns,
676 error_name,
677 op=op,
678 input_dtype=ifm.dtype,
679 weight_dtype=filter.dtype,
680 output_dtype=result_tens.dtype,
681 qinfo=qinfo,
682 input_list=input_list,
683 num_operands=num_operands,
684 output_list=output_list,
685 pad=padding,
686 stride=strides,
687 dilation=dilations,
688 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100689 weight_shape=filter.shape,
690 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000691 ):
692 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700693
694 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100695 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -0700696
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000697 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700698 return result_tens
699
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator and add it to the serializer.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode handed to
            ``addOperator`` and ``op["operands"]`` is summed to give the
            operand count used by the ERROR_IF validators.
        ifm, filter, bias: input, weight and bias tensors.
        accum_dtype: accumulator dtype, forwarded to the output shaper and
            to the convolution attribute.
        strides, padding, dilations: convolution parameters; ``padding``
            must contain exactly six entries.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.
        qinfo: indexable pair whose ``[0]``/``[1]`` are the input/output
            zero points.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure the new output type has correct qinfo: for a WrongInputType case
    # whose input is not 8-bit, regenerate both zero points from the actual
    # input and output dtypes.
    regenerate_zero_points = (
        error_name == ErrorIf.WrongInputType
        and ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if regenerate_zero_points:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [input_zp, output_zp]

    # Invalidate Input/Output list for error_if checks.
    inputs = [ifm.name, filter.name, bias.name]
    outputs = [out_tensor.name]
    operand_total = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
771
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator and add it to the serializer.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` is summed into the operand count.
        ifm, filter, bias: input, weight and bias tensors.
        accum_dtype: accumulator dtype for the shaper and attribute.
        stride: convolution stride.
        out_pad: output padding; must contain exactly four entries.
        output_shape: requested output shape, forwarded to the output shaper
            and stored in the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.
        qinfo: indexable pair whose ``[0]``/``[1]`` are the input/output
            zero points.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure the new output type has correct qinfo (non-8-bit input under a
    # WrongInputType test gets freshly generated zero points).
    regenerate_zero_points = (
        error_name == ErrorIf.WrongInputType
        and ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if regenerate_zero_points:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [input_zp, output_zp]

    # Invalidate Input/Output list for error_if checks.
    inputs = [ifm.name, filter.name, bias.name]
    outputs = [out_tensor.name]
    operand_total = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
836
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator and add it to the serializer.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` is summed into the operand count.
        ifm, filter, bias: input, weight and bias tensors.
        accum_dtype: accumulator dtype for the shaper and attribute.
        strides, padding, dilations: convolution parameters.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.
        qinfo: indexable pair whose ``[0]``/``[1]`` are the input/output
            zero points.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure the new output type has correct qinfo (non-8-bit input under a
    # WrongInputType test gets freshly generated zero points).
    regenerate_zero_points = (
        error_name == ErrorIf.WrongInputType
        and ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if regenerate_zero_points:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [input_zp, output_zp]

    # Invalidate Input/Output list for error_if checks.
    inputs = [ifm.name, filter.name, bias.name]
    outputs = [out_tensor.name]
    operand_total = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
907
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator and add it to the serializer.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks to two counts that are added together.
        ifm, filter, bias: input, weight and bias tensors.
        accum_dtype: accumulator dtype for the shaper, validators and attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.
        qinfo: indexable pair whose ``[0]``/``[1]`` are passed to the
            attribute as zero points.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    inputs = [ifm.name, filter.name, bias.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
956
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Build a MATMUL operator from tensors ``a`` and ``b``.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks to two counts that are added together.
        a, b: the two input tensors.
        accum_dtype: accumulator dtype for the shaper, validators and attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.
        qinfo: indexable pair whose ``[0]``/``[1]`` are passed to the
            attribute as zero points.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name, b.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
998
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a reduction operator over ``axis`` of tensor ``a``.

    ``validator_fcns`` now defaults to None, matching every other builder;
    existing positional and keyword callers are unaffected.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks to two counts that are added together.
        a: input tensor.
        axis: axis to reduce; also stored in an AxisAttribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1033
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Build a CLAMP operator on ``a`` with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller becomes
    the minimum; for the MaxSmallerMin ERROR_IF case the values are redrawn
    until they differ and then deliberately inverted. FP16/FLOAT inputs
    store the bounds in the float slots of the attribute, other dtypes in
    the integer slots.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        # Draw order (first, then second) matches the random stream.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Make sure the numbers are different to invoke this error
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        # Inverted on purpose: the "max" is the smaller value.
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype in (DType.FP16, DType.FLOAT):
        int_bounds, float_bounds = (0, 0), (min_val, max_val)
    else:
        int_bounds, float_bounds = (min_val, max_val), (0, 0)
    attr.ClampAttribute(*int_bounds, *float_bounds)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1084
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Build a LEAKY_RELU operator on ``a`` with a random FLOAT attribute value.

    No ERROR_IF validation is run here: ``validator_fcns`` is accepted for
    signature compatibility but unused, and ``error_name`` is only forwarded
    to the output shaper.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    attr = ts.TosaSerializerAttribute()

    alpha = self.getRandNumberDType(DType.FLOAT)
    attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], attr)
    return out_tensor
1093
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Build a PRELU operator from the single input tensor ``a``.

    The note above records that PRELU needs an additional type/input, which
    this builder does not yet supply. No ERROR_IF validation is run and no
    attribute is attached.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    inputs, outputs = [a.name], [out_tensor.name]
    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1100
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Build a SIGMOID operator on ``a`` (no attribute).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1131
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Build a TANH operator on ``a`` (no attribute).

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1162
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Build a CONCAT operator.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks to two counts that are added together.
        *a: the input tensors followed by the concatenation axis as the last
            positional argument. The axis is asserted to be exactly ``int``
            unless the WrongInputType case is under test.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # A variable-length tensor list is passed with the axis stored at its end.
    axis = a[-1]
    tensors = a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    inputs = [t.name for t in tensors]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=out_tensor.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001211
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator on ``a``.

    Args:
        op: operator description dict; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks to two counts that are added together.
        a: input tensor.
        padding: padding amounts; must support ``flatten()`` (presumably a
            NumPy array of per-dimension before/after pads — verify).
        pad_const_int, pad_const_float: pad values stored in the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF case under test, or None.
        qinfo: forwarded to the ERROR_IF validators only.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001257
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Build a RESHAPE operator that reshapes ``a`` to ``newShape``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1293
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build a REVERSE operator on ``a`` along ``axis``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1328
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Build a TRANSPOSE operator applying permutation ``perms`` to ``a``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1363
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Build a SLICE operator taking ``size`` elements of ``a`` from ``start``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1401
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator replicating ``a`` by ``multiples``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the case.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # Invalidate Input/Output list for error if checks.
    inputs = [a.name]
    outputs = [out_tensor.name]
    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, inputs, outputs
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], inputs, outputs, attr)
    return out_tensor
1435
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER operator over ``values``.

    A constant INT32 indices tensor of shape (values.shape[0], W) is created
    here. W comes from ``self.randInt`` over ``args.tensor_shape_range``, and
    every index lies in [0, values.shape[1]) so it never exceeds the K
    dimension of ``values``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    n_dim = values.shape[0]
    k_dim = values.shape[1]
    shape_lo = self.args.tensor_shape_range[0]
    shape_hi = self.args.tensor_shape_range[1]
    w_dim = self.randInt(shape_lo, shape_hi)

    # rng.integers treats ``high`` as exclusive, so every index is < K.
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[n_dim, w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    p_count, c_count = op["operands"]

    # The input/output name lists may be altered here for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1482
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER operator over ``values_in`` and ``input``.

    A constant INT32 indices tensor of shape (values_in.shape[0], W) is
    created here. W = input.shape[1], and every index lies in
    [0, values_in.shape[1]) so it stays inside the K dimension of
    ``values_in``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    index_shape = [values_in.shape[0], w_dim]

    # rng.integers treats ``high`` as exclusive, so every index is < K.
    index_data = np.int32(self.rng.integers(low=0, high=k_dim, size=index_shape))
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out_tensor = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    p_count, c_count = op["operands"]

    # The input/output name lists may be altered here for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [out_tensor.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1527
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a RESIZE operator on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are passed both to the
    output-shape computation (OutputShaper.resizeOp) and to the serialized
    ResizeAttribute.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]

    # The input/output name lists may be altered here for ERROR_IF tests.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tensor.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=inputs,
        output_list=outputs,
        result_tensor=out_tensor,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], inputs, outputs, resize_attr)
    return out_tensor
1589
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize an IDENTITYN operator over the two tensors ``val`` and ``val2``.

    Both output tensors are created with OutputShaper.unaryOp; only the first
    one is returned. ``validator_fcns`` is not used here.

    Returns the first output tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator-definition dict (every other builder passes
    # op["op"]); hand the operator itself to the serializer, not the dict.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1597
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark ``val`` as an output tensor of the graph and return it unchanged.

    No operator is emitted; ``op``, ``validator_fcns`` and ``error_name`` are
    unused here.
    """
    tensor = val
    self.ser.addOutputTensor(tensor)
    return tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001601
1602 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize a CAST operator converting ``val`` to ``out_dtype``.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    cast_out = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    p_count, c_count = op["operands"]

    # The input/output name lists may be altered here for ERROR_IF tests.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [cast_out.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=cast_out.shape,
        input_dtype=val.dtype,
        output_dtype=cast_out.dtype,
        result_tensor=cast_out,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if checks_ok:
        self.ser.addOperator(op["op"], in_names, out_names)
        return cast_out
    return None
1635
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize a RESCALE operator converting ``val`` to ``out_dtype``.

    Generates random input/output zero points (forced non-zero for the
    zero-point ERROR_IF tests) and random scales. There is one scale per
    last-dimension channel when ``per_channel`` is set, otherwise a single
    scale. Each scale is converted to a multiplier/shift pair with
    TosaQuantGen.computeMultiplierAndShift. For non-error scale32 tests, the
    input placeholder data file is clamped in place so its values meet the
    apply_scale_32 range requirement.

    Returns the output tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) or a single shared scale.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input zero point. Every branch that can produce a non-zero zero point
    # also widens in_type_width by one bit.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Output zero point, same scheme as the input side.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^out_type_width / 2^in_type_width), with a in [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Compute the bounds explicitly in int64. shift_arr[i] is an int32
        # scalar. Under NumPy 2 promotion rules (NEP 50), `-1 << (s - 1)`
        # would stay int32 and give wrong bounds for large shifts. NumPy 1.x
        # already promoted the same expression to int64.
        shift = np.int64(shift_arr[i])
        min_shift_value_arr[i] = np.int64(-1) << (shift - 1)
        max_shift_value_arr[i] = (np.int64(1) << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1790
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored.
    ``cond`` is stored as a BOOL constant that selects the branch. For the
    CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the matching
    block outputs a constant whose shape differs from the COND_IF result.

    Returns the COND_IF result tensor, or None when ERROR_IF validation fails.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])
    out_shape = then_tens.shape

    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        # Perturb every dimension: dims > 3 move by -3, -2, +2 or +3; smaller
        # dims are increased by 1, 2 or 4.
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            bad_shape[idx] = dim + delta
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The COND_IF result uses the shape of the (correct) branch outputs.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block only outputs a constant. When this block's mismatch error is
    # being tested, the constant has the wrong shape.
    for block_name, mismatch_error, block_arr in (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    ):
        self.ser.startBasicBlock(block_name)
        if error_name == mismatch_error:
            block_const = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_const = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_const)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1858
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Serialize a COND_IF whose THEN/ELSE blocks apply a binary op to a and b.

    For FLOAT/FP16/INT32 inputs, the THEN block uses ADD and the ELSE block
    uses SUB. For INT8/INT16 inputs they use LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT. For the CondIf{Input,Output}List{Then,Else}GraphMismatch
    ERROR_IF tests, the relevant block gets an input or output of a perturbed
    shape.

    Returns the COND_IF result tensor, or None when ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable: rebinding ``op`` here would replace the
    # operator dict that is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.startBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return result_tens
1940
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Serialize a WHILE_LOOP carrying (counter, a, accumulator).

    The counter starts at ``iter_val`` and the accumulator at zeros shaped
    like ``a``. The COND block emits GREATER(counter, 0). The BODY block emits
    ADD(a, acc) and SUB(counter, 1) and passes ``a`` through. For the
    InputList*/InputListOutputListMismatch ERROR_IF tests, the relevant
    tensors get perturbed shapes. For CondGraphOutputNotMatchingBool, the
    condition output gets a non-BOOL dtype.

    Returns the accumulator output tensor, or None when ERROR_IF validation
    fails.
    """
    # Named iter_tens (not ``iter``) to avoid shadowing the builtin.
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so give it a dimension to force a mismatch.
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.startBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_tens = self.ser.addOutput(
            [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
        )
    else:
        cond_tens = self.ser.addOutput([], DType.BOOL)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.startBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.basicBlocks,
    ):
        return None

    return acc_out
2052
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002053 def create_filter_lists(
2054 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
2055 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002056 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2057 default_test_rank_range = range(1, 5)
2058 if not shapeFilter:
2059 shapeFilter = [None]
2060
2061 # Calculate the filters based on what is requested and what the operator allows
2062 rmin, rmax = op["rank"]
2063 if rankFilter is not None:
2064 cleanRankFilter = []
2065 # Ensure rankFilter values are allowed by operator
2066 for rank in rankFilter:
2067 if rank >= rmin and rank <= rmax:
2068 cleanRankFilter.append(rank)
2069 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01002070 # Ensure default behaviour is bounded by default range or by operator,
2071 # whichever is the smaller range of ranks.
2072 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002073 cleanRankFilter = (
2074 opRankRange
2075 if len(opRankRange) <= len(default_test_rank_range)
2076 else default_test_rank_range
2077 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002078 else:
2079 cleanRankFilter = range(rmin, rmax + 1)
2080
2081 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002082
Matthew Haddon1c00b712021-10-01 15:51:03 +01002083 if dtypeFilter is not None:
2084 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002085 # Create list of operator dtypes filtered by requested dtypes
2086 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002087 if dtype in dtypeFilter or (
2088 isinstance(dtype, list) and dtype[0] in dtypeFilter
2089 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002090 cleanDtypeFilter.append(dtype)
2091 else:
2092 cleanDtypeFilter = dtypes
2093
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002094 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002095 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002096 "shapeFilter": shapeFilter,
2097 "rankFilter": cleanRankFilter,
2098 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002099 }
2100 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002101 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002102 if validator is not None:
2103 validator_info = validator(check=False, op=op)
2104 else:
2105 return None
2106
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002107 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002108
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002109 # Set parameters as required
2110 if error_arguments["rank"] is not None:
2111 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002112 else:
2113 rankFilter = cleanRankFilter
2114
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002115 if error_arguments["dtype"] is not None:
2116 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002117 else:
2118 dtypeFilter = cleanDtypeFilter
2119
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002120 if error_arguments["shape"] is not None:
2121 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002122 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002123 shapeFilter = shapeFilter[
2124 :2
2125 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002126
2127 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002128 "shapeFilter": shapeFilter,
2129 "rankFilter": rankFilter,
2130 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002131 }
2132 return filterDict
2133
Kevin Cheng550ccc52021-03-03 11:21:43 -08002134 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002135 self,
2136 opName,
2137 shapeFilter=[None],
2138 rankFilter=None,
2139 dtypeFilter=None,
2140 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08002141 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002142
2143 try:
2144 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002145 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002146 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002147
2148 # Initialize a new random number generator
2149 self.rng = np.random.default_rng(self.random_seed)
2150
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002151 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002152
Eric Kunzee5e26762020-10-13 16:11:07 -07002153 # Test list consists of a tuple of:
2154 # (opName, testNameStr, dtype, shapeList, argumentsList)
2155 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002156 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002157 error_if_validators = op["error_if_validators"]
2158 else:
2159 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002160
Matthew Haddon1c00b712021-10-01 15:51:03 +01002161 for validator in error_if_validators:
2162 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002163 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002164 else:
2165 error_name = None
2166
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002167 filterDict = self.create_filter_lists(
2168 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
2169 )
2170 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01002171 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002172 cleanRankFilter = filterDict["rankFilter"]
2173 cleanDtypeFilter = filterDict["dtypeFilter"]
2174 cleanShapeFilter = filterDict["shapeFilter"]
2175 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002176
2177 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002178 for t in cleanDtypeFilter:
2179 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002180 # Filter out by rank
2181 if shape is not None and len(shape) != r:
2182 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002183 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002184 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002185
Matthew Haddon74567092021-07-16 15:38:20 +01002186 shapeStr = self.shapeStr(shapeList[0])
2187 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002188
Matthew Haddon74567092021-07-16 15:38:20 +01002189 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2190 argList = []
2191 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002192 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002193 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002194 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002195
Matthew Haddon74567092021-07-16 15:38:20 +01002196 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002197 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002198 if argStr:
2199 testStr = "{}_{}_{}_{}".format(
2200 opName, shapeStr, typeStr, argStr
2201 )
2202 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002203 testStr = "{}_{}_{}".format(
2204 opName, shapeStr, typeStr
2205 )
2206 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01002207 if argStr:
2208 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2209 opName, error_name, shapeStr, typeStr, argStr
2210 )
2211 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002212 testStr = "{}_ERRORIF_{}_{}_{}".format(
2213 opName, error_name, shapeStr, typeStr
2214 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002215
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002216 testList.append(
2217 (opName, testStr, t, error_name, shapeList, args)
2218 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002219
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002220 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002221 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2222 if "invalid_test_validators" in op:
2223 invalid_test_validators = op["invalid_test_validators"]
2224 clean_testList = []
2225 for test in testList:
2226 for validator_fcn in invalid_test_validators:
2227 remove_test = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002228 if validator_fcn(
2229 opName=test[0],
2230 input_dtype=test[2],
2231 shapeList=test[4],
2232 args=test[5],
2233 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002234 remove_test = True
2235 if not remove_test:
2236 clean_testList.append(test)
2237 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002238
2239 return testList
2240
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002241 def serializeTest(
2242 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2243 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002244 try:
2245 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002246 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002247 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002248
2249 # Create a serializer
2250 self.createSerializer(opName, testStr)
2251
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002252 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002253 if "error_if_validators" in op:
2254 error_if_validators = op["error_if_validators"]
2255 else:
2256 error_if_validators = None
2257
Kevin Cheng550ccc52021-03-03 11:21:43 -08002258 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002259 num_operands = pCount + cCount
2260
2261 if isinstance(dtype_or_dtypeList, list):
2262 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002263 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002264 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002265 else:
2266 dtypeList = [dtype_or_dtypeList] * (num_operands)
2267
Kevin Cheng93a16282021-08-31 16:14:03 -07002268 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002269 assert (
2270 len(shapeList) == num_operands
2271 ), "shapeList length {} must match number of operands {}".format(
2272 len(shapeList), num_operands
2273 )
2274 assert (
2275 len(dtypeList) == num_operands
2276 ), "dtypeList length {} must match number of operands {}".format(
2277 len(dtypeList), num_operands
2278 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002279
2280 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002281 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002282 except KeyError:
2283 qgen = None
2284
2285 # Build the random tensor operands and the test
2286 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08002287
Matthew Haddon1c00b712021-10-01 15:51:03 +01002288 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002289 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002290 else:
2291 qinfo = None
2292
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002293 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002294
Matthew Haddon1c00b712021-10-01 15:51:03 +01002295 try:
2296 if error_if_validators is None:
2297 if qinfo is not None:
2298 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
2299 else:
2300 resultName = build_fcn(self, op, *tens, *testArgs)
2301 else:
2302 if qinfo is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002303 resultName = build_fcn(
2304 self,
2305 op,
2306 *tens,
2307 *testArgs,
2308 validator_fcns=error_if_validators,
2309 error_name=error_name,
2310 qinfo=qinfo,
2311 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002312 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002313 resultName = build_fcn(
2314 self,
2315 op,
2316 *tens,
2317 *testArgs,
2318 validator_fcns=error_if_validators,
2319 error_name=error_name,
2320 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002321 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00002322 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002323 raise e
2324
Les Bell729b0352021-11-24 10:28:21 +00002325 if resultName:
2326 # The test is valid, serialize it
2327 self.serialize("test")
2328 else:
2329 # The test is not valid
2330 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002331
Eric Kunzee5e26762020-10-13 16:11:07 -07002332 def createDynamicOpLists(self):
2333
Jeremy Johnson00423432022-09-12 17:27:37 +01002334 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2335 # Already created these lists (can occur when class is initialized more than once)
2336 return
2337
Eric Kunzee5e26762020-10-13 16:11:07 -07002338 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07002339 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002340
Kevin Cheng1533b852021-09-01 12:51:58 -07002341 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002342 testName = "conv2d_{}x{}".format(k[0], k[1])
2343 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2344 self.TOSA_OP_LIST[testName]["filter"] = k
2345 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002346
Kevin Cheng550ccc52021-03-03 11:21:43 -08002347 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2348 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2349 "depthwise_conv2d_TEMPLATE"
2350 ].copy()
2351 self.TOSA_OP_LIST[testName]["filter"] = k
2352 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002353
Kevin Cheng550ccc52021-03-03 11:21:43 -08002354 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2355 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2356 "transpose_conv2d_TEMPLATE"
2357 ].copy()
2358 self.TOSA_OP_LIST[testName]["filter"] = k
2359 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002360
Kevin Cheng1533b852021-09-01 12:51:58 -07002361 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2362 for k in KERNELS_3D:
2363 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2364 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2365 self.TOSA_OP_LIST[testName]["filter"] = k
2366 self.TOSA_OP_LIST[testName]["template"] = False
2367
Eric Kunzee5e26762020-10-13 16:11:07 -07002368 # Delete any templates after having created any dynamic ops
2369 # This is a two-pass operation because it's bad practice to delete
2370 # keys from dictionaries while iterating
2371 keyList = []
2372 for k in self.TOSA_OP_LIST:
2373 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002374 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002375 keyList.append(k)
2376 continue
2377 except KeyError:
2378 pass
2379
2380 for k in keyList:
2381 del self.TOSA_OP_LIST[k]
2382
2383 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002384 """Fill in default fields for ops if they aren't already specified.
2385 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002386 for op in self.TOSA_OP_LIST:
2387
2388 # Required fields
2389 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002390 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002391 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002392 raise Exception(
2393 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2394 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002395
2396 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002397 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002398 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002399 raise Exception(
2400 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2401 op
2402 )
2403 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002404
2405 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002406 _ = self.TOSA_OP_LIST[op]["types"]
2407 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002408 raise Exception(
2409 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2410 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002411
2412 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002413 _ = self.TOSA_OP_LIST[op]["op"]
2414 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002415 raise Exception(
2416 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2417 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002418
2419 # Put in default rank range, if missing
2420 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002421 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002422 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002423 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002424
    # Tensor operator list (TOSA_OP_LIST below). Each entry is keyed by test
    # name and has the fields:
    # 'op': the tosa Op enum value (e.g. Op.ARGMAX)
    # 'operands': tuple of (placeholder, const) operand counts
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    #         if not specified, defaults to DEFAULT_RANK_RANGE,
    #         i.e. (1, TOSA_TENSOR_MAX_RANK)
    # 'build_fcn': 4-tuple of (build_operator() function, TosaTensorGen shape
    #              function, TosaTensorValuesGen function, TosaArgGen function
    #              or None)
    # 'types': list of datatypes to be tested; an entry may itself be a list
    #          of per-operand dtypes (see TYPE_CONV)
    # 'qgen': optional TosaQuantGen function producing quantization info
    # 'invalid_test_validators': optional TosaInvalidValidator functions used to
    #          drop positive tests that are expected to fail
    # 'error_if_validators': optional TosaErrorValidator functions used to
    #          generate negative (ERROR_IF) tests
    # 'template': optional, True for *_TEMPLATE entries that
    #          createDynamicOpLists expands into per-kernel ops (adding 'filter')
    TYPE_FP = [DType.FLOAT, DType.FP16]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.FP16,
        DType.FLOAT,
    ]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32]  # floating-types and INT32
    # Floating, integer (excluding INT4) and boolean types
    TYPE_FIB = [
        DType.FP16,
        DType.FLOAT,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.BOOL,
    ]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]  # FLOAT and INT16 only

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]

    # Convolution-style type combinations. A list entry gives one dtype per
    # operand (presumably [input, weight, bias/accumulator] - TODO confirm
    # against the conv build functions); a bare DType is replicated for every
    # operand by serializeTest.
    TYPE_CONV = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        [DType.FP16, DType.FP16, DType.FP16],
        [DType.FP16, DType.FP16, DType.FLOAT],
        DType.FLOAT,
    ]

    # Rank range used by initOpListDefaults for ops without a 'rank' field
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002467
2468 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002469 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002470 "argmax": {
2471 "op": Op.ARGMAX,
2472 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002473 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002474 "build_fcn": (
2475 build_argmax,
2476 TosaTensorGen.tgBasic,
2477 TosaTensorValuesGen.tvgDefault,
2478 TosaArgGen.agAxis,
2479 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002480 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002481 "error_if_validators": (
2482 TosaErrorValidator.evAxisSmallerZero,
2483 TosaErrorValidator.evAxisLargerRank,
2484 TosaErrorValidator.evArgmaxOutputRankMismatch,
2485 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2486 TosaErrorValidator.evWrongRank,
2487 TosaErrorValidator.evWrongInputType,
2488 TosaErrorValidator.evWrongOutputType,
2489 TosaErrorValidator.evWrongInputList,
2490 TosaErrorValidator.evWrongOutputList,
2491 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002492 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002493 "avg_pool2d": {
2494 "op": Op.AVG_POOL2D,
2495 "operands": (1, 0),
2496 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002497 "build_fcn": (
2498 build_pool2d,
2499 TosaTensorGen.tgNHWC,
2500 TosaTensorValuesGen.tvgDefault,
2501 TosaArgGen.agPooling,
2502 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002503 "qgen": TosaQuantGen.qgUnary,
2504 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002505 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002506 "error_if_validators": (
2507 TosaErrorValidator.evKernelSmallerOne,
2508 TosaErrorValidator.evStrideSmallerOne,
2509 TosaErrorValidator.evPadSmallerZero,
2510 TosaErrorValidator.evWrongRank,
2511 TosaErrorValidator.evWrongInputType,
2512 TosaErrorValidator.evWrongOutputType,
2513 TosaErrorValidator.evWrongInputList,
2514 TosaErrorValidator.evWrongOutputList,
2515 TosaErrorValidator.evInputZeroPointNotZero,
2516 TosaErrorValidator.evOutputZeroPointNotZero,
2517 TosaErrorValidator.evPadLargerEqualKernel,
2518 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002519 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002520 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002521 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002522 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002523 "conv2d_TEMPLATE": {
2524 "op": Op.CONV2D,
2525 "operands": (1, 2),
2526 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002527 "build_fcn": (
2528 build_conv2d,
2529 TosaTensorGen.tgConv2D,
2530 TosaTensorValuesGen.tvgDefault,
2531 TosaArgGen.agConv,
2532 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002533 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002534 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002535 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2536 "error_if_validators": (
2537 TosaErrorValidator.evWrongInputType,
2538 TosaErrorValidator.evWrongOutputType,
2539 TosaErrorValidator.evWrongInputList,
2540 TosaErrorValidator.evWrongOutputList,
2541 TosaErrorValidator.evInputZeroPointNotZero,
2542 TosaErrorValidator.evWeightZeroPointNotZero,
2543 TosaErrorValidator.evPadSmallerZero,
2544 TosaErrorValidator.evStrideSmallerOne,
2545 TosaErrorValidator.evDilationSmallerOne,
2546 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002547 TosaErrorValidator.evConvOutputShapeMismatch,
2548 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002549 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002550 "template": True,
2551 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002552 # Templated operator. Filled in by createDynamicOpLists
2553 "conv3d_TEMPLATE": {
2554 "op": Op.CONV3D,
2555 "operands": (1, 2),
2556 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002557 "build_fcn": (
2558 build_conv3d,
2559 TosaTensorGen.tgConv3D,
2560 TosaTensorValuesGen.tvgDefault,
2561 TosaArgGen.agConv,
2562 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002563 "qgen": TosaQuantGen.qgConv,
2564 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002565 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2566 "error_if_validators": (
2567 TosaErrorValidator.evWrongInputType,
2568 TosaErrorValidator.evWrongOutputType,
2569 TosaErrorValidator.evWrongInputList,
2570 TosaErrorValidator.evWrongOutputList,
2571 TosaErrorValidator.evInputZeroPointNotZero,
2572 TosaErrorValidator.evWeightZeroPointNotZero,
2573 TosaErrorValidator.evPadSmallerZero,
2574 TosaErrorValidator.evStrideSmallerOne,
2575 TosaErrorValidator.evDilationSmallerOne,
2576 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002577 TosaErrorValidator.evConvOutputShapeMismatch,
2578 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002579 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002580 "template": True,
2581 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002582 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002583 "depthwise_conv2d_TEMPLATE": {
2584 "op": Op.DEPTHWISE_CONV2D,
2585 "operands": (1, 2),
2586 "filter": [1, 1],
2587 "rank": (4, 4),
2588 "build_fcn": (
2589 build_depthwise_conv2d,
2590 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002591 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002592 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002593 ),
2594 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002595 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002596 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2597 "error_if_validators": (
2598 TosaErrorValidator.evWrongInputType,
2599 TosaErrorValidator.evWrongOutputType,
2600 TosaErrorValidator.evWrongInputList,
2601 TosaErrorValidator.evWrongOutputList,
2602 TosaErrorValidator.evInputZeroPointNotZero,
2603 TosaErrorValidator.evWeightZeroPointNotZero,
2604 TosaErrorValidator.evPadSmallerZero,
2605 TosaErrorValidator.evStrideSmallerOne,
2606 TosaErrorValidator.evDilationSmallerOne,
2607 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002608 TosaErrorValidator.evConvOutputShapeMismatch,
2609 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002610 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002611 "template": True,
2612 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002613 "fully_connected": {
2614 "op": Op.FULLY_CONNECTED,
2615 "operands": (1, 2),
2616 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002617 "build_fcn": (
2618 build_fully_connected,
2619 TosaTensorGen.tgFullyConnected,
2620 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002621 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002622 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002623 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002624 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002625 "error_if_validators": (
2626 TosaErrorValidator.evInputZeroPointNotZero,
2627 TosaErrorValidator.evWeightZeroPointNotZero,
2628 TosaErrorValidator.evWrongRank,
2629 TosaErrorValidator.evWrongInputType,
2630 TosaErrorValidator.evWrongOutputType,
2631 TosaErrorValidator.evWrongInputList,
2632 TosaErrorValidator.evWrongOutputList,
2633 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002634 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002635 "matmul": {
2636 "op": Op.MATMUL,
2637 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002638 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002639 "build_fcn": (
2640 build_matmul,
2641 TosaTensorGen.tgMatmul,
2642 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002643 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002644 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002645 "qgen": TosaQuantGen.qgMatmul,
2646 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002647 "error_if_validators": (
2648 TosaErrorValidator.evInputZeroPointNotZero,
2649 TosaErrorValidator.evWrongRank,
2650 TosaErrorValidator.evWrongInputType,
2651 TosaErrorValidator.evWrongOutputType,
2652 TosaErrorValidator.evWrongInputList,
2653 TosaErrorValidator.evWrongOutputList,
2654 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002655 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002656 "max_pool2d": {
2657 "op": Op.MAX_POOL2D,
2658 "operands": (1, 0),
2659 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002660 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002661 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002662 TosaTensorGen.tgNHWC,
2663 TosaTensorValuesGen.tvgDefault,
2664 TosaArgGen.agPooling,
2665 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002666 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002667 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002668 "error_if_validators": (
2669 TosaErrorValidator.evKernelSmallerOne,
2670 TosaErrorValidator.evStrideSmallerOne,
2671 TosaErrorValidator.evPadSmallerZero,
2672 TosaErrorValidator.evWrongRank,
2673 TosaErrorValidator.evWrongInputType,
2674 TosaErrorValidator.evWrongOutputType,
2675 TosaErrorValidator.evWrongInputList,
2676 TosaErrorValidator.evWrongOutputList,
2677 TosaErrorValidator.evPadLargerEqualKernel,
2678 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002679 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002680 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002681 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002682 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002683 "transpose_conv2d_TEMPLATE": {
2684 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002685 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002686 "rank": (4, 4),
2687 "build_fcn": (
2688 build_transpose_conv2d,
2689 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002690 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002691 TosaArgGen.agTransposeConv2D,
2692 ),
2693 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002694 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002695 "invalid_test_validators": (
2696 TosaInvalidValidator.ivHeightWidthInvalid,
2697 TosaInvalidValidator.ivNonPositiveOutputShape,
2698 ),
2699 "error_if_validators": (
2700 TosaErrorValidator.evWrongInputType,
2701 TosaErrorValidator.evWrongOutputType,
2702 TosaErrorValidator.evWrongInputList,
2703 TosaErrorValidator.evWrongOutputList,
2704 TosaErrorValidator.evInputZeroPointNotZero,
2705 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002706 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002707 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002708 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002709 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002710 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002711 "template": True,
2712 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002713 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002714 "clamp": {
2715 "op": Op.CLAMP,
2716 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002717 "build_fcn": (
2718 build_clamp,
2719 TosaTensorGen.tgBasic,
2720 TosaTensorValuesGen.tvgDefault,
2721 None,
2722 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002723 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002724 "error_if_validators": (
2725 TosaErrorValidator.evMaxSmallerMin,
2726 TosaErrorValidator.evWrongInputType,
2727 TosaErrorValidator.evWrongOutputType,
2728 TosaErrorValidator.evWrongInputList,
2729 TosaErrorValidator.evWrongOutputList,
2730 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002731 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002732 "sigmoid": {
2733 "op": Op.SIGMOID,
2734 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002735 "build_fcn": (
2736 build_sigmoid,
2737 TosaTensorGen.tgBasic,
2738 TosaTensorValuesGen.tvgDefault,
2739 None,
2740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002741 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002742 "error_if_validators": (
2743 TosaErrorValidator.evWrongInputType,
2744 TosaErrorValidator.evWrongOutputType,
2745 TosaErrorValidator.evWrongInputList,
2746 TosaErrorValidator.evWrongOutputList,
2747 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002748 },
2749 "tanh": {
2750 "op": Op.TANH,
2751 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002752 "build_fcn": (
2753 build_tanh,
2754 TosaTensorGen.tgBasic,
2755 TosaTensorValuesGen.tvgDefault,
2756 None,
2757 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002758 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002759 "error_if_validators": (
2760 TosaErrorValidator.evWrongInputType,
2761 TosaErrorValidator.evWrongOutputType,
2762 TosaErrorValidator.evWrongInputList,
2763 TosaErrorValidator.evWrongOutputList,
2764 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002765 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002766 # Elementwise Binary Operators
2767 "add": {
2768 "op": Op.ADD,
2769 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002770 "build_fcn": (
2771 build_binary_broadcast,
2772 TosaTensorGen.tgBroadcastFuzz,
2773 TosaTensorValuesGen.tvgAddSub,
2774 None,
2775 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002776 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002777 "error_if_validators": (
2778 TosaErrorValidator.evRankMismatch,
2779 TosaErrorValidator.evWrongInputType,
2780 TosaErrorValidator.evWrongOutputType,
2781 TosaErrorValidator.evWrongInputList,
2782 TosaErrorValidator.evWrongOutputList,
2783 TosaErrorValidator.evDimensionMismatch,
2784 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002786 "arithmetic_right_shift": {
2787 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2788 "operands": (2, 0),
2789 "build_fcn": (
2790 build_arithmetic_right_shift,
2791 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002792 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002793 TosaArgGen.agArithmeticRightShift,
2794 ),
2795 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002796 "error_if_validators": (
2797 TosaErrorValidator.evRankMismatch,
2798 TosaErrorValidator.evWrongInputType,
2799 TosaErrorValidator.evWrongOutputType,
2800 TosaErrorValidator.evWrongInputList,
2801 TosaErrorValidator.evWrongOutputList,
2802 TosaErrorValidator.evDimensionMismatch,
2803 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002804 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002805 "bitwise_and": {
2806 "op": Op.BITWISE_AND,
2807 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002808 "build_fcn": (
2809 build_binary_broadcast,
2810 TosaTensorGen.tgBroadcastFuzz,
2811 TosaTensorValuesGen.tvgDefault,
2812 None,
2813 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002814 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002815 "error_if_validators": (
2816 TosaErrorValidator.evRankMismatch,
2817 TosaErrorValidator.evWrongInputType,
2818 TosaErrorValidator.evWrongOutputType,
2819 TosaErrorValidator.evWrongInputList,
2820 TosaErrorValidator.evWrongOutputList,
2821 TosaErrorValidator.evDimensionMismatch,
2822 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002823 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002824 "bitwise_or": {
2825 "op": Op.BITWISE_OR,
2826 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002827 "build_fcn": (
2828 build_binary_broadcast,
2829 TosaTensorGen.tgBroadcastFuzz,
2830 TosaTensorValuesGen.tvgDefault,
2831 None,
2832 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002833 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002834 "error_if_validators": (
2835 TosaErrorValidator.evRankMismatch,
2836 TosaErrorValidator.evWrongInputType,
2837 TosaErrorValidator.evWrongOutputType,
2838 TosaErrorValidator.evWrongInputList,
2839 TosaErrorValidator.evWrongOutputList,
2840 TosaErrorValidator.evDimensionMismatch,
2841 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002842 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002843 "bitwise_xor": {
2844 "op": Op.BITWISE_XOR,
2845 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002846 "build_fcn": (
2847 build_binary_broadcast,
2848 TosaTensorGen.tgBroadcastFuzz,
2849 TosaTensorValuesGen.tvgDefault,
2850 None,
2851 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002852 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002853 "error_if_validators": (
2854 TosaErrorValidator.evRankMismatch,
2855 TosaErrorValidator.evWrongInputType,
2856 TosaErrorValidator.evWrongOutputType,
2857 TosaErrorValidator.evWrongInputList,
2858 TosaErrorValidator.evWrongOutputList,
2859 TosaErrorValidator.evDimensionMismatch,
2860 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002861 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002862 "intdiv": {
2863 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002864 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002865 "build_fcn": (
2866 build_binary_broadcast,
2867 TosaTensorGen.tgBroadcastFuzz,
2868 TosaTensorValuesGen.tvgIntDiv,
2869 None,
2870 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002871 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002872 "error_if_validators": (
2873 TosaErrorValidator.evRankMismatch,
2874 TosaErrorValidator.evWrongInputType,
2875 TosaErrorValidator.evWrongOutputType,
2876 TosaErrorValidator.evWrongInputList,
2877 TosaErrorValidator.evWrongOutputList,
2878 TosaErrorValidator.evDimensionMismatch,
2879 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002880 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002881 "logical_and": {
2882 "op": Op.LOGICAL_AND,
2883 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002884 "build_fcn": (
2885 build_binary_broadcast,
2886 TosaTensorGen.tgBroadcastFuzz,
2887 TosaTensorValuesGen.tvgDefault,
2888 None,
2889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002890 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002891 "error_if_validators": (
2892 TosaErrorValidator.evRankMismatch,
2893 TosaErrorValidator.evWrongInputType,
2894 TosaErrorValidator.evWrongOutputType,
2895 TosaErrorValidator.evWrongInputList,
2896 TosaErrorValidator.evWrongOutputList,
2897 TosaErrorValidator.evDimensionMismatch,
2898 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002899 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002900 "logical_left_shift": {
2901 "op": Op.LOGICAL_LEFT_SHIFT,
2902 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002903 "build_fcn": (
2904 build_binary_broadcast,
2905 TosaTensorGen.tgBroadcastFuzz,
2906 TosaTensorValuesGen.tvgLogicalShift,
2907 None,
2908 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002909 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002910 "error_if_validators": (
2911 TosaErrorValidator.evRankMismatch,
2912 TosaErrorValidator.evWrongInputType,
2913 TosaErrorValidator.evWrongOutputType,
2914 TosaErrorValidator.evWrongInputList,
2915 TosaErrorValidator.evWrongOutputList,
2916 TosaErrorValidator.evDimensionMismatch,
2917 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002918 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002919 "logical_right_shift": {
2920 "op": Op.LOGICAL_RIGHT_SHIFT,
2921 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002922 "build_fcn": (
2923 build_binary_broadcast,
2924 TosaTensorGen.tgBroadcastFuzz,
2925 TosaTensorValuesGen.tvgLogicalShift,
2926 None,
2927 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002928 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002929 "error_if_validators": (
2930 TosaErrorValidator.evRankMismatch,
2931 TosaErrorValidator.evWrongInputType,
2932 TosaErrorValidator.evWrongOutputType,
2933 TosaErrorValidator.evWrongInputList,
2934 TosaErrorValidator.evWrongOutputList,
2935 TosaErrorValidator.evDimensionMismatch,
2936 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002937 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002938 "logical_or": {
2939 "op": Op.LOGICAL_OR,
2940 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002941 "build_fcn": (
2942 build_binary_broadcast,
2943 TosaTensorGen.tgBroadcastFuzz,
2944 TosaTensorValuesGen.tvgDefault,
2945 None,
2946 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002947 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002948 "error_if_validators": (
2949 TosaErrorValidator.evRankMismatch,
2950 TosaErrorValidator.evWrongInputType,
2951 TosaErrorValidator.evWrongOutputType,
2952 TosaErrorValidator.evWrongInputList,
2953 TosaErrorValidator.evWrongOutputList,
2954 TosaErrorValidator.evDimensionMismatch,
2955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002956 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002957 "logical_xor": {
2958 "op": Op.LOGICAL_XOR,
2959 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002960 "build_fcn": (
2961 build_binary_broadcast,
2962 TosaTensorGen.tgBroadcastFuzz,
2963 TosaTensorValuesGen.tvgDefault,
2964 None,
2965 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002966 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002967 "error_if_validators": (
2968 TosaErrorValidator.evRankMismatch,
2969 TosaErrorValidator.evWrongInputType,
2970 TosaErrorValidator.evWrongOutputType,
2971 TosaErrorValidator.evWrongInputList,
2972 TosaErrorValidator.evWrongOutputList,
2973 TosaErrorValidator.evDimensionMismatch,
2974 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002976 "maximum": {
2977 "op": Op.MAXIMUM,
2978 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002979 "build_fcn": (
2980 build_binary_broadcast,
2981 TosaTensorGen.tgBroadcastFuzz,
2982 TosaTensorValuesGen.tvgDefault,
2983 None,
2984 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002985 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002986 "error_if_validators": (
2987 TosaErrorValidator.evRankMismatch,
2988 TosaErrorValidator.evWrongInputType,
2989 TosaErrorValidator.evWrongOutputType,
2990 TosaErrorValidator.evWrongInputList,
2991 TosaErrorValidator.evWrongOutputList,
2992 TosaErrorValidator.evDimensionMismatch,
2993 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002994 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002995 "minimum": {
2996 "op": Op.MINIMUM,
2997 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002998 "build_fcn": (
2999 build_binary_broadcast,
3000 TosaTensorGen.tgBroadcastFuzz,
3001 TosaTensorValuesGen.tvgDefault,
3002 None,
3003 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003004 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003005 "error_if_validators": (
3006 TosaErrorValidator.evRankMismatch,
3007 TosaErrorValidator.evWrongInputType,
3008 TosaErrorValidator.evWrongOutputType,
3009 TosaErrorValidator.evWrongInputList,
3010 TosaErrorValidator.evWrongOutputList,
3011 TosaErrorValidator.evDimensionMismatch,
3012 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003013 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003014 "mul": {
3015 "op": Op.MUL,
3016 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003017 "build_fcn": (
3018 build_mul,
3019 TosaTensorGen.tgBroadcastFuzz,
3020 TosaTensorValuesGen.tvgMul,
3021 TosaArgGen.agMul,
3022 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003023 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003024 "error_if_validators": (
3025 TosaErrorValidator.evWrongInputType,
3026 TosaErrorValidator.evWrongOutputType,
3027 TosaErrorValidator.evWrongInputList,
3028 TosaErrorValidator.evWrongOutputList,
3029 TosaErrorValidator.evRankMismatch,
3030 TosaErrorValidator.evDimensionMismatch,
3031 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003032 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003033 "pow": {
3034 "op": Op.POW,
3035 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003036 "build_fcn": (
3037 build_binary_broadcast,
3038 TosaTensorGen.tgBroadcastFuzz,
3039 TosaTensorValuesGen.tvgDefault,
3040 None,
3041 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003042 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003043 "error_if_validators": (
3044 TosaErrorValidator.evRankMismatch,
3045 TosaErrorValidator.evWrongInputType,
3046 TosaErrorValidator.evWrongOutputType,
3047 TosaErrorValidator.evWrongInputList,
3048 TosaErrorValidator.evWrongOutputList,
3049 TosaErrorValidator.evDimensionMismatch,
3050 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003051 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003052 "sub": {
3053 "op": Op.SUB,
3054 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003055 "build_fcn": (
3056 build_binary_broadcast,
3057 TosaTensorGen.tgBroadcastFuzz,
3058 TosaTensorValuesGen.tvgAddSub,
3059 None,
3060 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003061 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003062 "error_if_validators": (
3063 TosaErrorValidator.evRankMismatch,
3064 TosaErrorValidator.evWrongInputType,
3065 TosaErrorValidator.evWrongOutputType,
3066 TosaErrorValidator.evWrongInputList,
3067 TosaErrorValidator.evWrongOutputList,
3068 TosaErrorValidator.evDimensionMismatch,
3069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003071 "table": {
3072 "op": Op.TABLE,
3073 # Use the automatic generation functions to create the input array
3074 # but create the table tensor in the build function, as it may be
3075 # a different type from the input
3076 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003077 "build_fcn": (
3078 build_table,
3079 TosaTensorGen.tgBasic,
3080 TosaTensorValuesGen.tvgDefault,
3081 TosaArgGen.agTable,
3082 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003083 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003084 "error_if_validators": (
3085 TosaErrorValidator.evWrongInputType,
3086 TosaErrorValidator.evWrongOutputType,
3087 TosaErrorValidator.evWrongInputList,
3088 TosaErrorValidator.evWrongOutputList,
3089 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003090 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003091 # Elementwise Unary operators
3092 "abs": {
3093 "op": Op.ABS,
3094 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003095 "build_fcn": (
3096 build_unary,
3097 TosaTensorGen.tgBasic,
3098 TosaTensorValuesGen.tvgDefault,
3099 None,
3100 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003101 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003102 "error_if_validators": (
3103 TosaErrorValidator.evWrongInputType,
3104 TosaErrorValidator.evWrongOutputType,
3105 TosaErrorValidator.evWrongInputList,
3106 TosaErrorValidator.evWrongOutputList,
3107 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003108 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003109 "bitwise_not": {
3110 "op": Op.BITWISE_NOT,
3111 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003112 "build_fcn": (
3113 build_unary,
3114 TosaTensorGen.tgBasic,
3115 TosaTensorValuesGen.tvgDefault,
3116 None,
3117 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003118 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003119 "error_if_validators": (
3120 TosaErrorValidator.evWrongInputType,
3121 TosaErrorValidator.evWrongOutputType,
3122 TosaErrorValidator.evWrongInputList,
3123 TosaErrorValidator.evWrongOutputList,
3124 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "ceil": {
3127 "op": Op.CEIL,
3128 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003129 "build_fcn": (
3130 build_unary,
3131 TosaTensorGen.tgBasic,
3132 TosaTensorValuesGen.tvgDefault,
3133 None,
3134 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003135 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003136 "error_if_validators": (
3137 TosaErrorValidator.evWrongInputType,
3138 TosaErrorValidator.evWrongOutputType,
3139 TosaErrorValidator.evWrongInputList,
3140 TosaErrorValidator.evWrongOutputList,
3141 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003142 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "clz": {
3144 "op": Op.CLZ,
3145 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003146 "build_fcn": (
3147 build_unary,
3148 TosaTensorGen.tgBasic,
3149 TosaTensorValuesGen.tvgDefault,
3150 None,
3151 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003152 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003153 "error_if_validators": (
3154 TosaErrorValidator.evWrongInputType,
3155 TosaErrorValidator.evWrongOutputType,
3156 TosaErrorValidator.evWrongInputList,
3157 TosaErrorValidator.evWrongOutputList,
3158 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003160 "exp": {
3161 "op": Op.EXP,
3162 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003163 "build_fcn": (
3164 build_unary,
3165 TosaTensorGen.tgBasic,
3166 TosaTensorValuesGen.tvgDefault,
3167 None,
3168 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003169 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003170 "error_if_validators": (
3171 TosaErrorValidator.evWrongInputType,
3172 TosaErrorValidator.evWrongOutputType,
3173 TosaErrorValidator.evWrongInputList,
3174 TosaErrorValidator.evWrongOutputList,
3175 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 "floor": {
3178 "op": Op.FLOOR,
3179 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003180 "build_fcn": (
3181 build_unary,
3182 TosaTensorGen.tgBasic,
3183 TosaTensorValuesGen.tvgDefault,
3184 None,
3185 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003186 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003187 "error_if_validators": (
3188 TosaErrorValidator.evWrongInputType,
3189 TosaErrorValidator.evWrongOutputType,
3190 TosaErrorValidator.evWrongInputList,
3191 TosaErrorValidator.evWrongOutputList,
3192 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003194 "log": {
3195 "op": Op.LOG,
3196 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003197 "build_fcn": (
3198 build_unary,
3199 TosaTensorGen.tgBasic,
3200 TosaTensorValuesGen.tvgDefault,
3201 None,
3202 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003203 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003204 "error_if_validators": (
3205 TosaErrorValidator.evWrongInputType,
3206 TosaErrorValidator.evWrongOutputType,
3207 TosaErrorValidator.evWrongInputList,
3208 TosaErrorValidator.evWrongOutputList,
3209 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 "logical_not": {
3212 "op": Op.LOGICAL_NOT,
3213 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003214 "build_fcn": (
3215 build_unary,
3216 TosaTensorGen.tgBasic,
3217 TosaTensorValuesGen.tvgDefault,
3218 None,
3219 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003220 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003221 "error_if_validators": (
3222 TosaErrorValidator.evWrongInputType,
3223 TosaErrorValidator.evWrongOutputType,
3224 TosaErrorValidator.evWrongInputList,
3225 TosaErrorValidator.evWrongOutputList,
3226 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003227 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003228 "negate": {
3229 "op": Op.NEGATE,
3230 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003231 "build_fcn": (
3232 build_unary,
3233 TosaTensorGen.tgBasic,
3234 TosaTensorValuesGen.tvgNegate,
3235 None,
3236 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003237 "qgen": TosaQuantGen.qgUnary,
3238 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003239 "error_if_validators": (
3240 TosaErrorValidator.evInputZeroPointNotZero,
3241 TosaErrorValidator.evOutputZeroPointNotZero,
3242 TosaErrorValidator.evWrongInputType,
3243 TosaErrorValidator.evWrongOutputType,
3244 TosaErrorValidator.evWrongInputList,
3245 TosaErrorValidator.evWrongOutputList,
3246 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 "reciprocal": {
3249 "op": Op.RECIPROCAL,
3250 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 "build_fcn": (
3252 build_unary,
3253 TosaTensorGen.tgBasic,
3254 TosaTensorValuesGen.tvgDefault,
3255 None,
3256 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003257 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003258 "error_if_validators": (
3259 TosaErrorValidator.evWrongInputType,
3260 TosaErrorValidator.evWrongOutputType,
3261 TosaErrorValidator.evWrongInputList,
3262 TosaErrorValidator.evWrongOutputList,
3263 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003264 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 "rsqrt": {
3266 "op": Op.RSQRT,
3267 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 "build_fcn": (
3269 build_unary,
3270 TosaTensorGen.tgBasic,
3271 TosaTensorValuesGen.tvgDefault,
3272 None,
3273 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003274 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003275 "error_if_validators": (
3276 TosaErrorValidator.evWrongInputType,
3277 TosaErrorValidator.evWrongOutputType,
3278 TosaErrorValidator.evWrongInputList,
3279 TosaErrorValidator.evWrongOutputList,
3280 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 # Elementwise Ternary operators
3283 "select": {
3284 "op": Op.SELECT,
3285 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003286 "build_fcn": (
3287 build_select,
3288 TosaTensorGen.tgBroadcastFuzz,
3289 TosaTensorValuesGen.tvgSelect,
3290 None,
3291 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003293 "error_if_validators": (
3294 TosaErrorValidator.evRankMismatch,
3295 TosaErrorValidator.evWrongInputType,
3296 TosaErrorValidator.evWrongOutputType,
3297 TosaErrorValidator.evWrongInputList,
3298 TosaErrorValidator.evWrongOutputList,
3299 TosaErrorValidator.evDimensionMismatch,
3300 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 # Comparison operators
3303 "equal": {
3304 "op": Op.EQUAL,
3305 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003306 "build_fcn": (
3307 build_comparison,
3308 TosaTensorGen.tgBroadcastFuzz,
3309 TosaTensorValuesGen.tvgEqual,
3310 None,
3311 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003313 "error_if_validators": (
3314 TosaErrorValidator.evRankMismatch,
3315 TosaErrorValidator.evWrongInputType,
3316 TosaErrorValidator.evWrongOutputType,
3317 TosaErrorValidator.evWrongInputList,
3318 TosaErrorValidator.evWrongOutputList,
3319 TosaErrorValidator.evDimensionMismatch,
3320 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003322 "greater_equal": {
3323 "op": Op.GREATER_EQUAL,
3324 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003325 "build_fcn": (
3326 build_comparison,
3327 TosaTensorGen.tgBroadcastFuzz,
3328 TosaTensorValuesGen.tvgDefault,
3329 None,
3330 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003331 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003332 "error_if_validators": (
3333 TosaErrorValidator.evRankMismatch,
3334 TosaErrorValidator.evWrongInputType,
3335 TosaErrorValidator.evWrongOutputType,
3336 TosaErrorValidator.evWrongInputList,
3337 TosaErrorValidator.evWrongOutputList,
3338 TosaErrorValidator.evDimensionMismatch,
3339 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003340 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003341 "greater": {
3342 "op": Op.GREATER,
3343 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003344 "build_fcn": (
3345 build_comparison,
3346 TosaTensorGen.tgBroadcastFuzz,
3347 TosaTensorValuesGen.tvgDefault,
3348 None,
3349 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003351 "error_if_validators": (
3352 TosaErrorValidator.evRankMismatch,
3353 TosaErrorValidator.evWrongInputType,
3354 TosaErrorValidator.evWrongOutputType,
3355 TosaErrorValidator.evWrongInputList,
3356 TosaErrorValidator.evWrongOutputList,
3357 TosaErrorValidator.evDimensionMismatch,
3358 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 # Reduction operators
3361 "reduce_all": {
3362 "op": Op.REDUCE_ALL,
3363 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003364 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003365 "build_fcn": (
3366 build_reduce,
3367 TosaTensorGen.tgBasic,
3368 TosaTensorValuesGen.tvgDefault,
3369 TosaArgGen.agAxis,
3370 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003371 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003372 "error_if_validators": (
3373 TosaErrorValidator.evAxisLargerRank,
3374 TosaErrorValidator.evAxisSmallerZero,
3375 TosaErrorValidator.evShapeOfAxisNotOne,
3376 TosaErrorValidator.evWrongInputType,
3377 TosaErrorValidator.evWrongOutputType,
3378 TosaErrorValidator.evWrongRank,
3379 TosaErrorValidator.evWrongInputList,
3380 TosaErrorValidator.evWrongOutputList,
3381 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "reduce_any": {
3384 "op": Op.REDUCE_ANY,
3385 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003386 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 "build_fcn": (
3388 build_reduce,
3389 TosaTensorGen.tgBasic,
3390 TosaTensorValuesGen.tvgDefault,
3391 TosaArgGen.agAxis,
3392 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003394 "error_if_validators": (
3395 TosaErrorValidator.evAxisLargerRank,
3396 TosaErrorValidator.evAxisSmallerZero,
3397 TosaErrorValidator.evShapeOfAxisNotOne,
3398 TosaErrorValidator.evWrongInputType,
3399 TosaErrorValidator.evWrongOutputType,
3400 TosaErrorValidator.evWrongRank,
3401 TosaErrorValidator.evWrongInputList,
3402 TosaErrorValidator.evWrongOutputList,
3403 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003404 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 "reduce_max": {
3406 "op": Op.REDUCE_MAX,
3407 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003408 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003409 "build_fcn": (
3410 build_reduce,
3411 TosaTensorGen.tgBasic,
3412 TosaTensorValuesGen.tvgDefault,
3413 TosaArgGen.agAxis,
3414 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 "error_if_validators": (
3417 TosaErrorValidator.evAxisLargerRank,
3418 TosaErrorValidator.evAxisSmallerZero,
3419 TosaErrorValidator.evShapeOfAxisNotOne,
3420 TosaErrorValidator.evWrongInputType,
3421 TosaErrorValidator.evWrongOutputType,
3422 TosaErrorValidator.evWrongRank,
3423 TosaErrorValidator.evWrongInputList,
3424 TosaErrorValidator.evWrongOutputList,
3425 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003427 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003428 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003429 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003430 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003431 "build_fcn": (
3432 build_reduce,
3433 TosaTensorGen.tgBasic,
3434 TosaTensorValuesGen.tvgDefault,
3435 TosaArgGen.agAxis,
3436 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003437 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003438 "error_if_validators": (
3439 TosaErrorValidator.evAxisLargerRank,
3440 TosaErrorValidator.evAxisSmallerZero,
3441 TosaErrorValidator.evShapeOfAxisNotOne,
3442 TosaErrorValidator.evWrongInputType,
3443 TosaErrorValidator.evWrongOutputType,
3444 TosaErrorValidator.evWrongRank,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003448 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003449 "reduce_product": {
3450 "op": Op.REDUCE_PRODUCT,
3451 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003452 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003453 "build_fcn": (
3454 build_reduce,
3455 TosaTensorGen.tgBasic,
3456 TosaTensorValuesGen.tvgDefault,
3457 TosaArgGen.agAxis,
3458 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003460 "error_if_validators": (
3461 TosaErrorValidator.evAxisLargerRank,
3462 TosaErrorValidator.evAxisSmallerZero,
3463 TosaErrorValidator.evShapeOfAxisNotOne,
3464 TosaErrorValidator.evWrongInputType,
3465 TosaErrorValidator.evWrongOutputType,
3466 TosaErrorValidator.evWrongRank,
3467 TosaErrorValidator.evWrongInputList,
3468 TosaErrorValidator.evWrongOutputList,
3469 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003470 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 "reduce_sum": {
3472 "op": Op.REDUCE_SUM,
3473 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003474 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003475 "build_fcn": (
3476 build_reduce,
3477 TosaTensorGen.tgBasic,
3478 TosaTensorValuesGen.tvgReduceSum,
3479 TosaArgGen.agAxis,
3480 ),
James Ward8b390432022-08-12 20:48:56 +01003481 "types": (DType.FP16, DType.FLOAT, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 "error_if_validators": (
3483 TosaErrorValidator.evAxisLargerRank,
3484 TosaErrorValidator.evAxisSmallerZero,
3485 TosaErrorValidator.evShapeOfAxisNotOne,
3486 TosaErrorValidator.evWrongInputType,
3487 TosaErrorValidator.evWrongOutputType,
3488 TosaErrorValidator.evWrongRank,
3489 TosaErrorValidator.evWrongInputList,
3490 TosaErrorValidator.evWrongOutputList,
3491 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003493 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003494 "concat": {
3495 "op": Op.CONCAT,
3496 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003497 "build_fcn": (
3498 build_concat,
3499 TosaTensorGen.tgConcat,
3500 TosaTensorValuesGen.tvgConcat,
3501 TosaArgGen.agAxis,
3502 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003503 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003504 "error_if_validators": (
3505 TosaErrorValidator.evAxisLargerRank,
3506 TosaErrorValidator.evAxisSmallerZero,
3507 TosaErrorValidator.evConcatInputRankMismatch,
3508 TosaErrorValidator.evConcatShapeSumMismatch,
3509 TosaErrorValidator.evConcatInputDimMismatch,
3510 TosaErrorValidator.evWrongInputType,
3511 TosaErrorValidator.evWrongOutputType,
3512 TosaErrorValidator.evWrongOutputList,
3513 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003514 },
3515 "pad": {
3516 "op": Op.PAD,
3517 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003518 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003519 "build_fcn": (
3520 build_pad,
3521 TosaTensorGen.tgBasic,
3522 TosaTensorValuesGen.tvgDefault,
3523 TosaArgGen.agPad,
3524 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003525 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003526 "error_if_validators": (
3527 TosaErrorValidator.evWrongInputType,
3528 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003529 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 TosaErrorValidator.evWrongOutputType,
3531 TosaErrorValidator.evWrongInputList,
3532 TosaErrorValidator.evWrongOutputList,
3533 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003534 },
3535 "reshape": {
3536 "op": Op.RESHAPE,
3537 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003538 "build_fcn": (
3539 build_reshape,
3540 TosaTensorGen.tgBasic,
3541 TosaTensorValuesGen.tvgDefault,
3542 TosaArgGen.agReshape,
3543 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003544 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003545 "error_if_validators": (
3546 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3547 TosaErrorValidator.evWrongInputType,
3548 TosaErrorValidator.evWrongOutputType,
3549 TosaErrorValidator.evWrongInputList,
3550 TosaErrorValidator.evWrongOutputList,
3551 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003552 },
3553 "reverse": {
3554 "op": Op.REVERSE,
3555 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003556 "build_fcn": (
3557 build_reverse,
3558 TosaTensorGen.tgBasic,
3559 TosaTensorValuesGen.tvgDefault,
3560 TosaArgGen.agAxis,
3561 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003562 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003563 "error_if_validators": (
3564 TosaErrorValidator.evAxisSmallerZero,
3565 TosaErrorValidator.evAxisLargerRank,
3566 TosaErrorValidator.evWrongInputType,
3567 TosaErrorValidator.evWrongOutputType,
3568 TosaErrorValidator.evWrongInputList,
3569 TosaErrorValidator.evWrongOutputList,
3570 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003571 },
3572 "slice": {
3573 "op": Op.SLICE,
3574 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003575 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003576 "build_fcn": (
3577 build_slice,
3578 TosaTensorGen.tgBasic,
3579 TosaTensorValuesGen.tvgDefault,
3580 TosaArgGen.agSlice,
3581 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003583 "error_if_validators": (
3584 TosaErrorValidator.evStartSmallerZero,
3585 TosaErrorValidator.evSizeSmallerEqualZero,
3586 TosaErrorValidator.evStartSizeOutsideBounds,
3587 TosaErrorValidator.evSizeOutputShapeMismatch,
3588 TosaErrorValidator.evInputSizeStartLengthMismatch,
3589 TosaErrorValidator.evWrongRank,
3590 TosaErrorValidator.evWrongInputType,
3591 TosaErrorValidator.evWrongOutputType,
3592 TosaErrorValidator.evWrongInputList,
3593 TosaErrorValidator.evWrongOutputList,
3594 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 },
3596 "tile": {
3597 "op": Op.TILE,
3598 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 "build_fcn": (
3600 build_tile,
3601 TosaTensorGen.tgBasic,
3602 TosaTensorValuesGen.tvgDefault,
3603 TosaArgGen.agTile,
3604 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003605 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003606 "error_if_validators": (
3607 TosaErrorValidator.evWrongInputType,
3608 TosaErrorValidator.evWrongOutputType,
3609 TosaErrorValidator.evWrongInputList,
3610 TosaErrorValidator.evWrongOutputList,
3611 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003612 },
3613 "transpose": {
3614 "op": Op.TRANSPOSE,
3615 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003616 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003617 "build_fcn": (
3618 build_transpose,
3619 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003620 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003621 TosaArgGen.agTranspose,
3622 ),
3623 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 "error_if_validators": (
3625 TosaErrorValidator.evIndexOutsideBounds,
3626 TosaErrorValidator.evIndexUsedTwice,
3627 TosaErrorValidator.evWrongInputType,
3628 TosaErrorValidator.evWrongOutputType,
3629 TosaErrorValidator.evWrongInputList,
3630 TosaErrorValidator.evWrongOutputList,
3631 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003632 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003633 # Data nodes
3634 "const": {
3635 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003636 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 "build_fcn": (
3638 build_const,
3639 TosaTensorGen.tgBasic,
3640 TosaTensorValuesGen.tvgDefault,
3641 None,
3642 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003643 "types": TYPE_FIB,
3644 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003645 "identity": {
3646 "op": Op.IDENTITY,
3647 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003648 "build_fcn": (
3649 build_unary,
3650 TosaTensorGen.tgBasic,
3651 TosaTensorValuesGen.tvgDefault,
3652 None,
3653 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003654 "types": TYPE_FIB,
3655 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003656 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003657 "gather": {
3658 "op": Op.GATHER,
3659 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3660 "operands": (1, 0),
3661 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003662 "build_fcn": (
3663 build_gather,
3664 TosaTensorGen.tgBasic,
3665 TosaTensorValuesGen.tvgDefault,
3666 None,
3667 ),
James Ward8b390432022-08-12 20:48:56 +01003668 "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003669 "error_if_validators": (
3670 TosaErrorValidator.evWrongInputType,
3671 TosaErrorValidator.evWrongOutputType,
3672 TosaErrorValidator.evWrongInputList,
3673 TosaErrorValidator.evWrongOutputList,
3674 TosaErrorValidator.evWrongRank,
3675 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003676 },
3677 "scatter": {
3678 "op": Op.SCATTER,
3679 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003681 "operands": (2, 0),
3682 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003683 "build_fcn": (
3684 build_scatter,
3685 TosaTensorGen.tgScatter,
3686 TosaTensorValuesGen.tvgDefault,
3687 None,
3688 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003689 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003690 "error_if_validators": (
3691 TosaErrorValidator.evWrongInputType,
3692 TosaErrorValidator.evWrongOutputType,
3693 TosaErrorValidator.evWrongInputList,
3694 TosaErrorValidator.evWrongOutputList,
3695 TosaErrorValidator.evWrongRank,
3696 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003697 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003698 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003699 "resize": {
3700 "op": Op.RESIZE,
3701 "operands": (1, 0),
3702 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003703 "build_fcn": (
3704 build_resize,
3705 TosaTensorGen.tgNHWC,
3706 TosaTensorValuesGen.tvgDefault,
3707 TosaArgGen.agResize,
3708 ),
James Ward8b390432022-08-12 20:48:56 +01003709 "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003710 "invalid_test_validators": (
3711 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003712 ),
3713 "error_if_validators": (
3714 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003715 TosaErrorValidator.evScaleSmallerEqualZero,
3716 TosaErrorValidator.evScaleNLargerMax,
3717 TosaErrorValidator.evScaleDLargerMax,
3718 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003719 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003720 TosaErrorValidator.evBorderSmallerMin,
3721 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongRank,
3725 TosaErrorValidator.evWrongInputList,
3726 TosaErrorValidator.evWrongOutputList,
3727 TosaErrorValidator.evBatchMismatch,
3728 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003729 TosaErrorValidator.evResizeOutputShapeMismatch,
3730 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003731 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003732 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003733 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003734 "cast": {
3735 "op": Op.CAST,
3736 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003737 "build_fcn": (
3738 build_cast,
3739 TosaTensorGen.tgBasic,
3740 TosaTensorValuesGen.tvgDefault,
3741 TosaArgGen.agCast,
3742 ),
James Ward8b390432022-08-12 20:48:56 +01003743 "types": (
3744 DType.FP16,
3745 DType.FLOAT,
3746 DType.INT8,
3747 DType.INT16,
3748 DType.INT32,
3749 DType.BOOL,
3750 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003751 "error_if_validators": (
3752 TosaErrorValidator.evWrongInputType,
3753 TosaErrorValidator.evWrongOutputType,
3754 TosaErrorValidator.evWrongInputList,
3755 TosaErrorValidator.evWrongOutputList,
3756 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003757 },
3758 "rescale": {
3759 "op": Op.RESCALE,
3760 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003761 "build_fcn": (
3762 build_rescale,
3763 TosaTensorGen.tgBasic,
3764 TosaTensorValuesGen.tvgDefault,
3765 TosaArgGen.agRescale,
3766 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003767 "types": [
3768 DType.UINT8,
3769 DType.INT8,
3770 DType.INT16,
3771 DType.INT32,
3772 DType.INT48,
3773 DType.UINT16,
3774 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003775 "error_if_validators": (
3776 TosaErrorValidator.evInputZeroPointNotZero,
3777 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003778 TosaErrorValidator.evU16InputZeroPointNotValid,
3779 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003780 TosaErrorValidator.evScaleTrue,
3781 TosaErrorValidator.evScaleNotTrue,
3782 TosaErrorValidator.evWrongInputType,
3783 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003784 TosaErrorValidator.evWrongInputList,
3785 TosaErrorValidator.evWrongOutputList,
3786 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003787 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003788 # Custom
3789 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003791 # Two variants of cond_if, one that generates one of two constant tensors (no
3792 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3793 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003794 "cond_if_const": {
3795 "op": Op.COND_IF,
3796 "operands": (0, 2),
3797 "build_fcn": (
3798 build_cond_if_const,
3799 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003801 TosaArgGen.agCondIf,
3802 ),
3803 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003804 "error_if_validators": (
3805 TosaErrorValidator.evOutputListThenGraphMismatch,
3806 TosaErrorValidator.evOutputListElseGraphMismatch,
3807 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003808 },
3809 "cond_if_binary": {
3810 "op": Op.COND_IF,
3811 "operands": (2, 0),
3812 "build_fcn": (
3813 build_cond_if_binary,
3814 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003815 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003816 TosaArgGen.agCondIf,
3817 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003818 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 "error_if_validators": (
3820 TosaErrorValidator.evInputListThenGraphMismatch,
3821 TosaErrorValidator.evInputListElseGraphMismatch,
3822 TosaErrorValidator.evOutputListThenGraphMismatch,
3823 TosaErrorValidator.evOutputListElseGraphMismatch,
3824 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003825 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003826 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003827 "while_loop": {
3828 "op": Op.WHILE_LOOP,
3829 "operands": (0, 1),
3830 "build_fcn": (
3831 build_while_loop,
3832 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003833 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003834 TosaArgGen.agWhileLoop,
3835 ),
3836 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003837 "error_if_validators": (
3838 TosaErrorValidator.evInputListOutputListMismatch,
3839 TosaErrorValidator.evInputListCondGraphMismatch,
3840 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3841 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3842 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
3843 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003844 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003845 }
3846
Kevin Cheng550ccc52021-03-03 11:21:43 -08003847
Eric Kunzee5e26762020-10-13 16:11:07 -07003848class OutputShaper:
3849 # Methods in this class compute the expected output shape and datatype
3850 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; no instance state is set up here."""
3853
3854 # These methods return arguments that can be used for
3855 # creating a new output tensor
3856 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting element-wise binary operator.

    On the valid path (``error_name is None``), every dimension where ``a``
    has extent 1 takes ``b``'s extent. All other dimensions keep ``a``'s
    extent, and so does every dimension on an ERROR_IF path.

    The output dtype is ``a.dtype``. For a WrongOutputType test it is
    instead a different dtype drawn with ``rng``.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    out_shape = [
        b.shape[dim] if (extent == 1 and broadcast) else extent
        for dim, extent in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    # Any dtype other than the input's makes the output type invalid.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    invalid = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(invalid))
Eric Kunzee5e26762020-10-13 16:11:07 -07003883
3884 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a non-broadcasting element-wise binary operator.

    ``a`` and ``b`` are asserted to have the same rank, dtype and extents.
    The output takes ``a``'s dtype and a new list copy of ``a``'s shape.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, extent in enumerate(a.shape):
        assert extent == b.shape[idx]
        out_shape.append(extent)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003895
3896 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an element-wise unary operator.

    The shape is always ``a.shape``. The dtype is ``a.dtype``; for a
    WrongOutputType test it is instead a different dtype drawn with ``rng``.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003912
3913 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT.

    The output has ``cond``'s rank. On the valid path, a ``cond`` dimension
    of extent 1 becomes the largest of the matching ``cond``/``a``/``b``
    extents; otherwise ``cond``'s extent is kept.

    The dtype is ``a.dtype``, or a different dtype drawn with ``rng`` for a
    WrongOutputType test.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for dim, cond_extent in enumerate(cond.shape):
        if error_name is None and cond_extent == 1:
            out_shape.append(max(cond_extent, a.shape[dim], b.shape[dim]))
        else:
            out_shape.append(cond_extent)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    # Any dtype other than the operands' makes the output type invalid.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    invalid = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(invalid))
Eric Kunzee5e26762020-10-13 16:11:07 -07003940
3941 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an element-wise comparison operator.

    Broadcasting: where ``a`` has extent 1 and ``b`` has a dimension at the
    same index, ``b``'s extent is used; otherwise ``a``'s extent is kept.

    The dtype is BOOL. For a WrongOutputType test it is a randomly chosen
    non-BOOL dtype.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast. The rank check stops b.shape from being indexed past its
    # end when the ranks differ (the rank assert is skipped for RankMismatch).
    b_rank = len(b.shape)
    out_shape = [
        b.shape[dim] if (extent == 1 and dim < b_rank) else extent
        for dim, extent in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, the dtype used on the valid path.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
        )
    else:
        out_dtype = DType.BOOL
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003968
3969 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction operator.

    The reduced ``axis`` gets extent 1, except for the axis ERROR_IF tests:
    - AxisSmallerZero and AxisLargerRank leave the shape untouched.
    - ShapeOfAxisNotOne keeps that axis at an extent other than 1, replacing
      an input extent of 1 with a random value in [2, 10).

    The dtype is ``a.dtype``, or a different dtype drawn with ``rng`` for a
    WrongOutputType test.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()
    axis_errors = [
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    ]
    if error_name not in axis_errors:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003995
3996 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for ARGMAX.

    ``axis`` is removed from ``a``'s shape, except for the invalid-axis
    ERROR_IF tests. The ArgmaxOutputRankMismatch and
    ArgmaxOutputShapeMismatch tests then change the rank or grow every
    extent.

    The dtype is INT32, or a random non-INT32 dtype for a WrongOutputType
    test.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()
    if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Drop the leading dimension (only if at least one would remain),
        # otherwise append an extra unit dimension.
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [extent + rng.integers(1, 10) for extent in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    invalid = list(set(candidates) - {DType.INT32})
    return ser.addOutput(out_shape, rng.choice(invalid))
Eric Kunzee5e26762020-10-13 16:11:07 -07004027
4028 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV2D.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. Each spatial output extent is
    ``(in - 1 + pad_before + pad_after - (k - 1) * dilation) // stride + 1``.

    ConvOutputShapeMismatch tests grow H and/or W by a multiple of the
    stride.

    The dtype is ``accum_dtype``. It is INT32 for WrongInputType tests and a
    random invalid dtype for WrongOutputType tests.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """

    def spatial_extent(in_size, pad_before, pad_after, k_size, dilation, stride):
        return (
            in_size - 1 + pad_before + pad_after - (k_size - 1) * dilation
        ) // stride + 1

    h = spatial_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = spatial_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 grows H, 2 grows W, 3 grows both. The growth is a multiple of the
        # stride so the non-integer output error case is not hit.
        change = rng.choice(choices)
        if change in (1, 3):
            h += rng.choice(choices) * strides[0]
        if change in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # INT32 is some potentially correct output dtype when the input type is wrong.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FLOAT]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004079
4080 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for CONV3D.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. Spatial axis ``i`` (D, H, W)
    has output extent
    ``(ifm[1+i] - 1 + pad[2i] + pad[2i+1] - (filter[1+i] - 1) * dil[i])
    // stride[i] + 1``.

    ConvOutputShapeMismatch tests grow one of D/H/W, or all three, by a
    multiple of the matching stride.

    The dtype is ``accum_dtype``. It is INT32 for WrongInputType tests and a
    random invalid dtype for WrongOutputType tests.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    d, h, w = (
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # 1/2/3 grow only D/H/W respectively and 4 grows all three. The growth
        # is a multiple of the stride so the non-integer error case is not hit.
        change = rng.choice(choices)
        if change in (1, 4):
            d += rng.choice(choices) * strides[0]
        if change in (2, 4):
            h += rng.choice(choices) * strides[1]
        if change in (3, 4):
            w += rng.choice(choices) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FLOAT] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4141
4142 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M). Kernel extents come from
    filter dims 0 and 1. The output channel count is
    ``filter.shape[2] * filter.shape[3]``.

    ConvOutputShapeMismatch tests grow H and/or W by a multiple of the
    stride.

    The dtype is ``accum_dtype``. It is INT32 for WrongInputType tests and a
    random invalid dtype for WrongOutputType tests.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    kernel_h, kernel_w = filter.shape[0], filter.shape[1]
    h = (
        ifm.shape[1] - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    w = (
        ifm.shape[2] - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow in multiples of the stride to avoid the non-integer error case.
        grow_h = change in (1, 3)
        grow_w = change in (2, 3)
        if grow_h:
            h += rng.choice(choices) * strides[0]
        if grow_w:
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FLOAT] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004192
4193 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor for a 2D pooling operator (input and output NHWC).

    Spatial extents are ``(in + pad_before + pad_after - kernel) // stride + 1``.
    If a stride is non-positive or any pad is negative, H and W are both set
    to 1. PoolingOutputShapeMismatch tests grow H and/or W by a multiple of
    the stride.

    The dtype is ``ifm.dtype``, or a different dtype drawn with ``rng`` for a
    WrongOutputType test.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        h = w = 1
    else:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        # 1 grows H, 2 grows W, 3 grows both. The growth is a multiple of the
        # stride so the non-integer output error case is not hit.
        change = rng.choice(choices)
        if change in (1, 3):
            h += rng.choice(choices) * stride[0]
        if change in (2, 3):
            w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
        DType.FP16,
    ]
    invalid = list(set(candidates) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(invalid))
Eric Kunzee5e26762020-10-13 16:11:07 -07004229
4230 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor for FULLY_CONNECTED.

    Shapes: input (N, IC), filter (OC, IC), output (N, OC).

    ``rng`` and ``error_name`` are not used here. ``accum_dtype`` is the
    output dtype as given: it is validated in arg_gen, and is also
    invalidated there for ERROR_IF tests.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004242
4243 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor for MATMUL.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    Output dtype:
    - ``accum_dtype`` on the normal path (validated in arg_gen).
    - INT32 for WrongInputType tests.
    - A random dtype that is wrong for ``a.dtype`` on WrongOutputType tests.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FLOAT,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FLOAT,
            )
        elif a.dtype == DType.FLOAT or a.dtype == DType.FP16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # There is no hand-picked list for this input dtype, so use every
            # usable dtype except the expected accumulator. Without this branch
            # incorrect_types would be unbound (UnboundLocalError).
            incorrect_types = list(usableDTypes(excludes=[accum_dtype]))
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004284
4285 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor for CONCAT.

    The output starts as a copy of the first input's shape. The ``axis``
    extents of the remaining inputs are then added on, unless this is a
    rank-mismatch or invalid-axis ERROR_IF test. ConcatShapeSumMismatch tests
    inflate that extent further.

    The dtype is the first input's dtype, or a different one drawn with
    ``rng`` for a WrongOutputType test.

    Returns whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    head = a[0]
    tail = a[1:]

    out_shape = head.shape.copy()
    # Only sum along axis when the tensors can actually be concatenated there:
    # same rank and a valid axis.
    if error_name != ErrorIf.ConcatInputRankMismatch and error_name not in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]:
        for other in tail:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    out_dtype = head.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        }
        out_dtype = rng.choice(list(candidates - {head.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004318
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor for a PAD operator.

    Output dim i is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ErrorIf.PadOutputShapeMismatch shrinks one random dim by 1 or 2, and
    ErrorIf.WrongOutputType picks a dtype different from ``a.dtype``.
    """
    shape_out = a.shape.copy()
    for dim_idx, dim in enumerate(a.shape):
        shape_out[dim_idx] = padding[dim_idx][0] + padding[dim_idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(shape_out)))
        shape_out[victim] -= rng.choice([1, 2])

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
            DType.FP16,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(shape_out, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004346
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor for a RESHAPE operator.

    The output takes a copy of the requested ``shape``. For
    ErrorIf.TensorSizeInputOutputMismatch every dim is grown by a random
    amount in [1, 10) to provoke the size mismatch. ErrorIf.WrongOutputType
    picks a dtype different from ``a.dtype``.
    """
    shape_out = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(shape_out)):
            shape_out[idx] += rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(shape_out, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004369
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Create the output tensor for a SLICE operator.

    The output shape is a copy of ``size``; ``start`` is not used here.
    For ErrorIf.SizeOutputShapeMismatch every dim is perturbed by a non-zero
    step: dims <= 2 only grow (by 1 or 2), larger dims move by +/-1 or +/-2.
    ErrorIf.WrongOutputType picks a dtype different from ``a.dtype``.
    """
    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    shape_out = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(size):
            steps = [1, 2] if extent <= 2 else [-2, -1, 1, 2]
            shape_out[idx] = extent + rng.choice(steps)

    return ser.addOutput(shape_out, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004399
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor for a TILE operator.

    Output dim i is ``a.shape[i] * multiples[i]``; ``multiples`` must have
    one entry per input dim. ErrorIf.WrongOutputType picks a dtype
    different from ``a.dtype``.
    """
    shape_out = a.shape.copy()
    assert len(multiples) == len(shape_out)

    for idx, dim in enumerate(a.shape):
        shape_out[idx] = dim * multiples[idx]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape_out, a.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtypes = list(set(dtype_pool) - {a.dtype})
    return ser.addOutput(shape_out, rng.choice(bad_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004423
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor for a TRANSPOSE operator.

    Output dim i is ``a.shape[perms[i]]``. For ErrorIf.IndexOutsideBounds
    the permutation is not used for indexing (presumably because it holds
    out-of-range entries) and every output dim is ``a.shape[0]``.
    ErrorIf.WrongOutputType picks a dtype different from ``a.dtype``.
    """
    shape_out = a.shape.copy()

    assert len(perms) == len(shape_out)

    if error_name == ErrorIf.IndexOutsideBounds:
        source_axes = [0] * len(shape_out)
    else:
        source_axes = perms
    for idx in range(len(shape_out)):
        shape_out[idx] = a.shape[source_axes[idx]]

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        ]
        out_dtype = rng.choice(list(set(dtype_pool) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(shape_out, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004451
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor for a GATHER operator.

    ``values`` is rank 3 and ``indices`` rank 2, sharing dim 0; these checks
    are skipped for ErrorIf.WrongRank. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]`` with the dtype
    of ``values`` (or a different dtype for ErrorIf.WrongOutputType).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    shape_out = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape_out, values.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtypes = list(set(dtype_pool) - {values.dtype})
    return ser.addOutput(shape_out, rng.choice(bad_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004475
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor for a SCATTER operator.

    Unless ErrorIf.WrongRank is requested, ``values_in`` and ``input`` must be
    rank 3 and ``indices`` rank 2, with matching N, W and C extents. The
    output reuses ``values_in.shape`` (the same object, not a copy) and
    takes the dtype of ``values_in`` (or a different dtype for
    ErrorIf.WrongOutputType).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    shape_out = values_in.shape

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape_out, values_in.dtype)

    dtype_pool = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FLOAT,
    ]
    bad_dtypes = list(set(dtype_pool) - {values_in.dtype})
    return ser.addOutput(shape_out, rng.choice(bad_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004502
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor for a TABLE operator.

    The output keeps the input shape. INT16 input produces INT32 output and
    any other input dtype produces INT8. The input dtype is only asserted to
    be INT8/INT16 when ErrorIf.WrongInputType is not requested. For
    ErrorIf.WrongOutputType a different dtype is drawn from a fixed list.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        dtype_pool = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FLOAT,
        )
        candidates = [dt for dt in dtype_pool if dt != out_dtype]
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004520
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor for a RESIZE operator.

    Per spatial axis the output extent is
    ``((in - 1) * scale_n - offset + border) // scale_d + 1``, where
    ``scale`` is ``[y_n, y_d, x_n, x_d]`` and ``offset``/``border`` are
    ``[y, x]``. ERROR_IF cases then adjust the result so that the requested
    error is provoked while the tensor itself stays constructible.

    ``mode`` and ``input_dtype`` are not used here; the returned tensor
    always has ``output_dtype``.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the shape arithmetic below never divides by zero.
        y_num, y_den, x_num, x_den = (
            max(part, 1) for part in (y_num, y_den, x_num, x_den)
        )

    in_h = input.shape[1]
    in_w = input.shape[2]
    oh = ((in_h - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((in_w - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # scale/offset/border may have been altered for the ERROR_IF case,
        # so force a valid, in-range output size.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        change = rng.choice([1, 2, 3])

        def _nudge(extent, step):
            # Move by a whole scale denominator so the non-integer ERROR_IF
            # case is not triggered; step down if stepping up would overflow.
            if extent + step >= MAX_RESIZE_DIMENSION:
                extent -= step
                assert extent > 0  # Should have been caught in agResize
                return extent
            return extent + step

        if change in (1, 3):
            oh = _nudge(oh, y_den)
        if change in (2, 3):
            ow = _nudge(ow, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        output_dims = [batch, oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004599
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor with ``val``'s shape and dtype ``out_dtype``.

    ``rng`` and ``error_name`` are unused here (presumably kept so every
    shaper shares the same calling convention).
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004603
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor for a TRANSPOSE_CONV2D operator.

    The output uses ``output_shape`` and, normally, ``accum_dtype``.

    NOTE(review): for ErrorIf.ConvOutputShapeMismatch the caller's
    ``output_shape`` is modified in place (dim 1 and/or dim 2 grows by
    1..3); callers may rely on seeing the altered shape, so confirm before
    switching to a copy.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)
        if which in (1, 3):
            output_shape[1] = output_shape[1] + rng.choice(deltas)
        if which in (2, 3):
            output_shape[2] = output_shape[2] + rng.choice(deltas)

    # With an incorrect input type, INT32 is a potentially correct output.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input may legitimately accumulate in FP16 or FLOAT, so both
        # are excluded from the "wrong" candidates.
        excluded = (
            [DType.FP16, DType.FLOAT] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)