blob: d799eb0f961641f222f1eb51880d28427eae72d9 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnson05c711e2022-12-12 18:00:41 +000017from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010018from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010020from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025class TosaTestGen:
    # Highest tensor rank the test generator supports.
    TOSA_TENSOR_MAX_RANK = 6
28
Eric Kunzee5e26762020-10-13 16:11:07 -070029 def __init__(self, args):
30 self.args = args
31 self.basePath = args.output_dir
32 self.random_seed = args.random_seed
33 self.ser = None
34 self.rng = np.random.default_rng(self.random_seed)
35 self.createDynamicOpLists()
36 self.initOpListDefaults()
37 self.quantGen = TosaQuantGen()
38 # Force makeShape to do a specific starting shape
39 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010040 # Work out floating point range
41 self.random_fp_low = min(args.tensor_fp_value_range)
42 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070043
44 def createSerializer(self, opName, testPath):
45 self.testPath = os.path.join(opName, testPath)
46
47 fullPath = os.path.join(self.basePath, self.testPath)
48 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010049 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070050
51 def getSerializer(self):
52 return self.ser
53
54 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080055 with open(
56 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
57 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070058 fd.write(self.ser.serialize())
59
Kevin Cheng550ccc52021-03-03 11:21:43 -080060 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
61 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070062
Matthew Haddon74567092021-07-16 15:38:20 +010063 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000064 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010065 seed = self.random_seed + 1
66 self.rng = np.random.default_rng(seed)
67
Eric Kunzee5e26762020-10-13 16:11:07 -070068 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070069 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070070 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070071 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070072 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070073 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070074 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010075 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
76 elif dtype == DType.UINT8:
77 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070078 elif dtype == DType.INT16:
79 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010080 elif dtype == DType.UINT16:
81 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070082 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080083 return np.int32(
84 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
85 )
Eric Kunzee5e26762020-10-13 16:11:07 -070086 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080087 return np.int64(
88 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
89 )
James Ward8b390432022-08-12 20:48:56 +010090 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010091 return np.float16(
92 self.rng.uniform(
93 low=self.random_fp_low, high=self.random_fp_high, size=shape
94 )
95 )
James Ward24dbc422022-10-19 12:20:31 +010096 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010097 f32_tensor = np.float32(
98 self.rng.uniform(
99 low=self.random_fp_low, high=self.random_fp_high, size=shape
100 )
101 )
James Ward24dbc422022-10-19 12:20:31 +0100102 # Floor the last 16 bits of each f32 value
103 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100104 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100105 return np.float32(
106 self.rng.uniform(
107 low=self.random_fp_low, high=self.random_fp_high, size=shape
108 )
109 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700110 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800111 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700112
Kevin Cheng989cb052021-04-28 16:29:44 -0700113 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700114 placeholders = []
115
Kevin Cheng989cb052021-04-28 16:29:44 -0700116 assert len(shape_list) == len(dtype_list)
117
118 for idx, shape in enumerate(shape_list):
119 arr = self.getRandTensor(shape, dtype_list[idx])
120 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
122 return placeholders
123
Kevin Cheng989cb052021-04-28 16:29:44 -0700124 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700125 consts = []
126
Kevin Cheng989cb052021-04-28 16:29:44 -0700127 assert len(shape_list) == len(dtype_list)
128
129 for idx, shape in enumerate(shape_list):
130 arr = self.getRandTensor(shape, dtype_list[idx])
131 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700132
133 return consts
134
135 def makeShape(self, rank):
136 if self.targetted_shape:
137 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800138 return np.int32(
139 self.rng.integers(
140 low=self.args.tensor_shape_range[0],
141 high=self.args.tensor_shape_range[1],
142 size=rank,
143 )
144 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
146 def setTargetShape(self, shape):
147 self.targetted_shape = shape
148
149 def randInt(self, low=0, high=256):
150 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
151
152 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100153 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100154 return np.float32(
155 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
156 )
James Ward8b390432022-08-12 20:48:56 +0100157 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100158 return np.float16(
159 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
160 )
James Ward24dbc422022-10-19 12:20:31 +0100161 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100162 rand_f32 = np.float32(
163 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
164 )
James Ward24dbc422022-10-19 12:20:31 +0100165 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 elif dtype == DType.BOOL:
167 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700168 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700169 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700170 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700171 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100172 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 elif dtype == DType.INT16:
174 low, high = (-32768, 32768)
175 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800176 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800178 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 # Special size
180 return np.int64(self.rng.integers(low, high, size=1))[0]
181 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800182 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700183
184 return np.int32(self.rng.integers(low, high, size=1))[0]
185
186 def shapeStr(self, shape):
187
188 sStr = []
189 # Convert to strings
190 for i in shape:
191 sStr.append(str(i))
192
Kevin Cheng550ccc52021-03-03 11:21:43 -0800193 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100195 def typeStr(self, dtype):
196 if isinstance(dtype, list) or isinstance(dtype, tuple):
197 assert len(dtype) >= 2
198 strs = [self.typeStr(t) for t in dtype]
199 # Limit types to the first 2 as the 3rd is the accumulator
200 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100202 if dtype in DTYPE_ATTRIBUTES:
203 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700204 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100205 raise Exception(
206 "Unknown dtype, cannot convert to string: {}".format(dtype)
207 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100209 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100210 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100211 if dtype in DTYPE_ATTRIBUTES:
212 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100214 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
    # Argument generators
    # Each argument generator returns a list of tuples
    # (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
    # used to generate the test name and the build_fcn_arg_list is expanded
    # and passed to the operator test build function.
221
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100222 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
223 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
224
Matthew Haddon848efb42021-09-09 12:30:53 +0100225 # build_placeholder returns an int, ABS/other ops does not
226 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000227 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100228 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000229 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000230 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100231 return result_tens
232
233 # Ensure new output type has correct qinfo
234 if error_name == ErrorIf.WrongOutputType:
235 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000236 qinfo = [
237 TosaQuantGen.getZeroPoint(self, a.dtype),
238 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
239 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100240
241 # Invalidate Input/Output list for error if checks.
242 input_list = [a.name]
243 output_list = [result_tens.name]
244 pCount, cCount = op["operands"]
245 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000246 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
247 self, error_name, input_list, output_list
248 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100249
Les Bell729b0352021-11-24 10:28:21 +0000250 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100251 self.ser,
252 validator_fcns,
253 error_name,
254 op=op,
255 input_dtype=a.dtype,
256 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000257 qinfo=qinfo,
258 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100259 input_list=input_list,
260 output_list=output_list,
261 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000262 ):
263 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100264
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000265 attr = None
266 if op["op"] == Op.NEGATE:
267 attr = ts.TosaSerializerAttribute()
268 attr.NegateAttribute(qinfo[0], qinfo[1])
269
270 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271 return result_tens
272
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100273 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000274 result_tens = OutputShaper.binaryBroadcastOp(
275 self.ser, self.rng, a, b, error_name
276 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100277
278 # Invalidate Input/Output list for error if checks.
279 input_list = [a.name, b.name]
280 output_list = [result_tens.name]
281 pCount, cCount = op["operands"]
282 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000283 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
284 self, error_name, input_list, output_list
285 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100286
Les Bell729b0352021-11-24 10:28:21 +0000287 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100288 self.ser,
289 validator_fcns,
290 error_name,
291 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000292 input1=a,
293 input2=b,
294 input_dtype=a.dtype,
295 output_dtype=result_tens.dtype,
296 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100297 input_list=input_list,
298 output_list=output_list,
299 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000300 ):
301 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100302
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000303 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 return result_tens
305
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100306 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700307 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000308 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700309 return result_tens
310
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000311 def build_arithmetic_right_shift(
312 self, op, a, b, round, validator_fcns=None, error_name=None
313 ):
314 result_tens = OutputShaper.binaryBroadcastOp(
315 self.ser, self.rng, a, b, error_name
316 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100317
318 # Invalidate Input/Output list for error if checks.
319 input_list = [a.name, b.name]
320 output_list = [result_tens.name]
321 pCount, cCount = op["operands"]
322 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000323 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
324 self, error_name, input_list, output_list
325 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100326
Les Bell729b0352021-11-24 10:28:21 +0000327 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100328 self.ser,
329 validator_fcns,
330 error_name,
331 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000332 input1=a,
333 input2=b,
334 input_dtype=a.dtype,
335 output_dtype=result_tens.dtype,
336 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100337 input_list=input_list,
338 output_list=output_list,
339 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000340 ):
341 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800342
343 attr = ts.TosaSerializerAttribute()
344 attr.ArithmeticRightShiftAttribute(round)
345
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000346 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800347 return result_tens
348
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100349 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000350 result_tens = OutputShaper.binaryBroadcastOp(
351 self.ser, self.rng, a, b, error_name
352 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700353
354 # Special for multiply:
355 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100356 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700357 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100358 if error_name == ErrorIf.WrongOutputType:
359 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
360 outputDType = self.rng.choice(all_dtypes)
361 result_tens.setDtype(outputDType)
362
363 # Invalidate Input/Output list for error if checks.
364 input_list = [a.name, b.name]
365 output_list = [result_tens.name]
366 pCount, cCount = op["operands"]
367 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000368 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
369 self, error_name, input_list, output_list
370 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100371
Les Bell729b0352021-11-24 10:28:21 +0000372 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100373 self.ser,
374 validator_fcns,
375 error_name,
376 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000377 input1=a,
378 input2=b,
379 input_dtype=a.dtype,
380 output_dtype=result_tens.dtype,
381 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100382 input_list=input_list,
383 output_list=output_list,
384 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000385 ):
386 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700387
Kevin Chengaee1fac2020-11-11 13:54:06 -0800388 attr = ts.TosaSerializerAttribute()
389 attr.MulAttribute(shift)
390
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000391 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700392 return result_tens
393
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100394 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
395 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700396
Kevin Chengfe392ce2021-10-18 21:51:55 +0000397 attr = ts.TosaSerializerAttribute()
398 attr.TableAttribute(table)
399
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100400 # Invalidate Input/Output list for error if checks.
401 input_list = [a.name]
402 output_list = [result_tens.name]
403 pCount, cCount = op["operands"]
404 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000405 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
406 self, error_name, input_list, output_list
407 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100408
Les Bell729b0352021-11-24 10:28:21 +0000409 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410 self.ser,
411 validator_fcns,
412 error_name,
413 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000414 input_shape=a.shape,
415 input_dtype=a.dtype,
416 output_dtype=result_tens.dtype,
417 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100418 input_list=input_list,
419 output_list=output_list,
420 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000421 ):
422 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100423
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000424 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700425
426 return result_tens
427
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100428 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
429 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
430
431 # Invalidate Input/Output list for error if checks.
432 input_list = [cond.name, a.name, b.name]
433 output_list = [result_tens.name]
434 pCount, cCount = op["operands"]
435 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
437 self, error_name, input_list, output_list
438 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100439
Les Bell729b0352021-11-24 10:28:21 +0000440 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100441 self.ser,
442 validator_fcns,
443 error_name,
444 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000445 input1=cond,
446 input2=a,
447 input3=b,
448 input_shape=a.shape,
449 input_dtype=a.dtype,
450 output_dtype=result_tens.dtype,
451 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100452 input_list=input_list,
453 output_list=output_list,
454 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000455 ):
456 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100457
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000458 self.ser.addOperator(
459 op["op"],
460 input_list,
461 output_list,
462 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700463 return result_tens
464
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100465 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000466 result_tens = OutputShaper.binaryComparisonOp(
467 self.ser, self.rng, a, b, error_name
468 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100469
470 # Invalidate Input/Output list for error if checks.
471 input_list = [a.name, b.name]
472 output_list = [result_tens.name]
473 pCount, cCount = op["operands"]
474 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000475 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
476 self, error_name, input_list, output_list
477 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478
Les Bell729b0352021-11-24 10:28:21 +0000479 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100480 self.ser,
481 validator_fcns,
482 error_name,
483 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000484 input1=a,
485 input2=b,
486 input_shape=a.shape,
487 input_dtype=a.dtype,
488 output_shape=result_tens.shape,
489 output_dtype=result_tens.dtype,
490 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100491 input_list=input_list,
492 output_list=output_list,
493 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000494 ):
495 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100496
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000497 self.ser.addOperator(
498 op["op"],
499 input_list,
500 output_list,
501 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700502 return result_tens
503
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100504 def build_argmax(self, op, a, axis, validator_fcns, error_name):
505 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
506
507 # Invalidate Input/Output list for error if checks.
508 input_list = [a.name]
509 output_list = [result_tens.name]
510 pCount, cCount = op["operands"]
511 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000512 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
513 self, error_name, input_list, output_list
514 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100515
Les Bell729b0352021-11-24 10:28:21 +0000516 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100517 self.ser,
518 validator_fcns,
519 error_name,
520 op=op,
521 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000522 input_shape=a.shape,
523 input_dtype=a.dtype,
524 output_shape=result_tens.shape,
525 output_dtype=result_tens.dtype,
526 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100527 input_list=input_list,
528 output_list=output_list,
529 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000530 ):
531 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700532
533 attr = ts.TosaSerializerAttribute()
534 attr.AxisAttribute(axis)
535
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000536 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700537 return result_tens
538
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000539 def build_pool2d(
540 self,
541 op,
542 input,
James Ward8b390432022-08-12 20:48:56 +0100543 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000544 stride,
545 pad,
546 kernel,
547 validator_fcns=None,
548 error_name=None,
549 qinfo=None,
550 ):
551 result_tens = OutputShaper.pool2dOp(
552 self.ser, self.rng, input, kernel, stride, pad, error_name
553 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100554
555 # Ensure new output type has correct qinfo
556 if error_name == ErrorIf.WrongInputType:
557 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000558 qinfo = [
559 TosaQuantGen.getZeroPoint(self, input.dtype),
560 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
561 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100562
563 # Invalidate Input/Output list for error if checks.
564 input_list = [input.name]
565 output_list = [result_tens.name]
566 pCount, cCount = op["operands"]
567 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000568 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
569 self, error_name, input_list, output_list
570 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100571
Les Bell729b0352021-11-24 10:28:21 +0000572 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100573 self.ser,
574 validator_fcns,
575 error_name,
576 op=op,
577 input_shape=input.shape,
578 input_dtype=input.dtype,
579 output_shape=result_tens.shape,
580 output_dtype=result_tens.dtype,
581 kernel=kernel,
582 stride=stride,
583 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000584 qinfo=qinfo,
585 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100586 input_list=input_list,
587 output_list=output_list,
588 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000589 ):
590 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700591
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000592 if qinfo is None:
593 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700594
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000595 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100596 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000597
598 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700599 return result_tens
600
James Ward8b390432022-08-12 20:48:56 +0100601 def build_maxpool2d(
602 self,
603 op,
604 input,
605 stride,
606 pad,
607 kernel,
608 validator_fcns=None,
609 error_name=None,
610 qinfo=None,
611 ):
612 # Same as build_pool2d but manually sets accum_dtype value
613 # (maxpool has no accum_dtype)
614 return self.build_pool2d(
615 op,
616 input,
617 DType.UNKNOWN,
618 stride,
619 pad,
620 kernel,
621 validator_fcns,
622 error_name,
623 qinfo,
624 )
625
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000626 def build_conv2d(
627 self,
628 op,
629 ifm,
630 filter,
631 bias,
James Ward8b390432022-08-12 20:48:56 +0100632 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000633 strides,
634 padding,
635 dilations,
636 validator_fcns=None,
637 error_name=None,
638 qinfo=None,
639 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800640 assert len(padding) == 4
641 result_tens = OutputShaper.conv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100642 self.ser,
643 self.rng,
644 ifm,
645 filter,
646 accum_dtype,
647 strides,
648 padding,
649 dilations,
650 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000651 )
652
653 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000654 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
655 DType.INT8,
656 DType.UINT8,
657 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000658 qinfo = [
659 TosaQuantGen.getZeroPoint(self, ifm.dtype),
660 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
661 ]
Les Bell0e027d42021-11-09 14:42:14 +0000662
663 # Invalidate Input/Output list for error_if checks.
664 input_list = [ifm.name, filter.name, bias.name]
665 output_list = [result_tens.name]
666 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000667 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
668 self, error_name, input_list, output_list
669 )
Les Bell0e027d42021-11-09 14:42:14 +0000670
Les Bell729b0352021-11-24 10:28:21 +0000671 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000672 self.ser,
673 validator_fcns,
674 error_name,
675 op=op,
676 input_dtype=ifm.dtype,
677 weight_dtype=filter.dtype,
678 output_dtype=result_tens.dtype,
679 qinfo=qinfo,
680 input_list=input_list,
681 num_operands=num_operands,
682 output_list=output_list,
683 pad=padding,
684 stride=strides,
685 dilation=dilations,
686 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100687 weight_shape=filter.shape,
688 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000689 ):
690 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700691
692 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100693 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -0700694
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000695 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700696 return result_tens
697
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3D convolution operator (``op["op"]``) to the serializer.

    The result tensor is created by ``OutputShaper.conv3dOp``. For ERROR_IF
    tests the operand/result name lists may be invalidated, and the checks in
    ``validator_fcns`` run before anything is serialized.

    Args:
        op: Operator description; ``op["op"]`` is the opcode handed to
            ``addOperator`` and the entries of ``op["operands"]`` are summed
            to give the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype stored in the ``ConvAttribute``.
        strides: Stride values.
        padding: Padding values; exactly six are required.
        dilations: Dilation values.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Two zero points, used as ``qinfo[0]`` and ``qinfo[1]`` of the
            attribute. They are recomputed for a ``WrongInputType`` test
            whose input is not INT8/UINT8.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # A substituted (non 8-bit) input dtype needs zero points that match it.
    retype_qinfo = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if retype_qinfo:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    # ERROR_IF tests may swap in invalid operand/result name lists.
    operand_count = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
769
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D transpose convolution operator (``op["op"]``) to the serializer.

    The result tensor is created by ``OutputShaper.transposeConv2DOp`` from
    ``output_shape``. For ERROR_IF tests the name lists may be invalidated,
    and the validators run before serialization.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and the entries
            of ``op["operands"]`` are summed to give the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype stored in the ``TransposeConvAttribute``.
        stride: Stride values.
        out_pad: Output padding values; exactly four are required. They are
            also passed to the validators as ``pad``.
        output_shape: Requested output shape.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Two zero points, used as ``qinfo[0]`` and ``qinfo[1]`` of the
            attribute. They are recomputed for a ``WrongInputType`` test
            whose input is not INT8/UINT8.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # A substituted (non 8-bit) input dtype needs zero points that match it.
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
            ]

    # ERROR_IF tests may swap in invalid operand/result name lists.
    operand_count = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not is_valid:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, tconv_attr)
    return out_tensor
834
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a depthwise 2D convolution operator (``op["op"]``) to the serializer.

    The result tensor is created by ``OutputShaper.depthwiseConv2dOp``. For
    ERROR_IF tests the name lists may be invalidated, and the validators run
    before serialization.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and the entries
            of ``op["operands"]`` are summed to give the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype stored in the ``ConvAttribute``.
        strides: Stride values.
        padding: Padding values.
        dilations: Dilation values.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Two zero points, used as ``qinfo[0]`` and ``qinfo[1]`` of the
            attribute. They are recomputed for a ``WrongInputType`` test
            whose input is not INT8/UINT8.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # A substituted (non 8-bit) input dtype needs zero points that match it.
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
            second_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
            qinfo = [input_zp, second_zp]

    # ERROR_IF tests may swap in invalid operand/result name lists.
    operand_count = sum(op["operands"])
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_count,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
905
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully-connected operator (``op["op"]``) to the serializer.

    The result tensor is created by ``OutputShaper.fullyConnectedOp``.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator dtype, passed to the validators and stored in
            the ``FullyConnectedAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Two zero points, used as ``qinfo[0]`` and ``qinfo[1]`` of the
            attribute.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    operand_count = p_count + c_count
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
        accum_dtype=accum_dtype,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, fc_attr)
    return out_tensor
954
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Add a matrix-multiply operator (``op["op"]``) on ``a`` and ``b``.

    The result tensor is created by ``OutputShaper.matmulOp``.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: First input tensor.
        b: Second input tensor.
        accum_dtype: Accumulator dtype, passed to the validators and stored in
            the ``MatMulAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Two zero points, used as ``qinfo[0]`` and ``qinfo[1]`` of the
            attribute.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
        accum_dtype=accum_dtype,
    )
    if not is_valid:
        return None

    mm_attr = ts.TosaSerializerAttribute()
    mm_attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, mm_attr)
    return out_tensor
996
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reduction operator (``op["op"]``) over ``axis`` to the serializer.

    The result tensor is created by ``OutputShaper.reduceOp``, and the axis
    is serialized in an ``AxisAttribute``.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        axis: Axis to reduce along.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``. Defaults to None,
            like the other ``build_*`` methods.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1031
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add a clamp operator (``op["op"]``) with random bounds to the serializer.

    Two random values of ``a.dtype`` are drawn. Normally the smaller one is
    the minimum and the larger one the maximum. For the ``MaxSmallerMin``
    ERROR_IF test the values are redrawn until they differ and are then
    swapped, so that max < min.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    bounds = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot produce max < min, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = [
                self.getRandNumberDType(a.dtype),
                self.getRandNumberDType(a.dtype),
            ]
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # Float dtypes put their bounds in the last two slots, integer dtypes
    # in the first two.
    clamp_attr = ts.TosaSerializerAttribute()
    is_float = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if is_float:
        clamp_attr.ClampAttribute(0, 0, min_val, max_val)
    else:
        clamp_attr.ClampAttribute(min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], inputs, outputs, clamp_attr)
    return out_tensor
1082
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add a leaky-ReLU operator (``op["op"]``) on ``a`` to the serializer.

    The attribute gets a random FP32 value from ``getRandNumberDType``. No
    ERROR_IF validation runs here: ``validator_fcns`` is accepted but unused,
    and ``error_name`` is only forwarded to ``OutputShaper.unaryOp``.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(alpha)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1091
# NOTE: this op still needs an additional type/input; only the single
# input ``a`` is wired up below.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with the single input ``a``.

    The result tensor comes from ``OutputShaper.unaryOp``. No attribute is
    attached and no ERROR_IF validation runs; ``validator_fcns`` is unused.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1098
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add a sigmoid operator (``op["op"]``) on ``a`` to the serializer.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1129
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add a tanh operator (``op["op"]``) on ``a`` to the serializer.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1160
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add a concatenation operator (``op["op"]``) to the serializer.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        *a: The input tensors, followed by the concatenation axis as the
            final element. The axis has to travel with the variable-length
            list of inputs. It must be an ``int`` unless ``error_name`` is
            ``ErrorIf.WrongInputType``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(a[-1], int)

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001209
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a pad operator (``op["op"]``) on ``a`` to the serializer.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        padding: Pad amounts; ``padding.flatten()`` is stored in the
            attribute (presumably a NumPy array — confirm with callers).
        pad_const_int: Integer pad constant for the ``PadAttribute``.
        pad_const_float: Float pad constant for the ``PadAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Passed through to the validators only.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], inputs, outputs, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001255
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add a reshape operator (``op["op"]``) turning ``a`` into ``newShape``.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        newShape: Target shape, passed to ``OutputShaper.reshapeOp`` and
            stored in the ``ReshapeAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return out_tensor
1291
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add a reverse operator (``op["op"]``) along ``axis`` to the serializer.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        axis: Axis stored in the ``AxisAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
1326
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add a transpose operator (``op["op"]``) permuting ``a`` by ``perms``.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        perms: Permutation, passed to ``OutputShaper.transposeOp``, the
            validators and the ``TransposeAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    # The attribute is built before validation, as in the rest of this op's flow.
    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, perm_attr)
    return out_tensor
1361
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add a slice operator (``op["op"]``) on ``a`` to the serializer.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        start: Slice start, stored in the ``SliceAttribute``.
        size: Slice size, stored in the ``SliceAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return out_tensor
1399
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Add a tile operator (``op["op"]``) repeating ``a`` by ``multiples``.

    Args:
        op: Operator description; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into exactly two counts whose sum is
            the operand count.
        a: Input tensor.
        multiples: Passed to ``OutputShaper.tileOp`` and stored in the
            ``TileAttribute``.
        validator_fcns: ERROR_IF validators forwarded to
            ``TosaErrorValidator.evValidateErrorIfs``.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The result tensor, or None when ERROR_IF validation fails.
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # ERROR_IF tests may swap in invalid operand/result name lists.
    p_count, c_count = op["operands"]
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], inputs, outputs, tile_attr)
    return out_tensor
1433
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a gather operator reading from ``values``.

    A constant INT32 indices tensor of shape ``[values.shape[0], W]`` is
    created. ``W`` is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``. Every index is drawn from
    ``[0, values.shape[1])`` so it never exceeds that dimension of ``values``.

    Returns:
        The result tensor, or None when ``evValidateErrorIfs`` reports failure.
    """
    k_dim = values.shape[1]
    shape_range = self.args.tensor_shape_range
    width = self.randInt(shape_range[0], shape_range[1])
    idx_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], width])
    )
    idx_tens = self.ser.addConst(idx_data.shape, DType.INT32, idx_data)

    out_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, idx_tens, error_name
    )

    n_primary, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, idx_tens.name], [out_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out_tens.shape,
        input_dtype=values.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=n_primary + n_const,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1480
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a scatter operator with ``values_in`` and ``input`` operands.

    A constant INT32 indices tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is created. Every index is drawn
    from ``[0, values_in.shape[1])`` so it stays inside ``values_in``.

    Returns:
        The result tensor, or None when ``evValidateErrorIfs`` reports failure.
    """
    k_dim = values_in.shape[1]
    width = input.shape[1]
    idx_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], width])
    )
    idx_tens = self.ser.addConst(idx_data.shape, DType.INT32, idx_data)

    out_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, idx_tens, input, error_name
    )

    n_primary, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, idx_tens.name, input.name],
        [out_tens.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=n_primary + n_const,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1525
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a resize operator on tensor ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded unchanged to
    ``OutputShaper.resizeOp`` (which shapes the result) and to the
    ResizeAttribute.

    Returns:
        The result tensor, or None when ``evValidateErrorIfs`` reports failure.
    """
    out_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    n_primary, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out_tens.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensor=out_tens,
        num_operands=n_primary + n_const,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tens
1587
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize a two-input, two-output operator for ``val`` and ``val2``.

    Each output tensor is shaped by ``OutputShaper.unaryOp`` from its input.
    ``validator_fcns`` is accepted for signature compatibility but is unused.

    Args:
        op: operator description dict; the opcode is read from ``op["op"]``
            as in the other build_* methods. A bare opcode is also accepted
            for backward compatibility.

    Returns:
        The first result tensor only.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The serializer needs the opcode, not the whole op description dict
    # (previously the dict itself was passed to addOperator).
    opcode = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1595
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register the already-built tensor ``val`` as a graph output.

    ``op``, ``validator_fcns`` and ``error_name`` exist only to match the
    common build_* signature and are not used.

    Returns:
        ``val`` itself.
    """
    const_tens = val
    self.ser.addOutputTensor(const_tens)
    return const_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001599
1600 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize a type conversion of ``val`` to ``out_dtype``.

    The result tensor is created by ``OutputShaper.typeConversionOp``. The
    input/output name lists pass through
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList`` for ERROR_IF tests.

    Returns:
        The result tensor, or None when ``evValidateErrorIfs`` reports failure.
    """
    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    n_primary, n_const = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out_tens.shape,
        input_dtype=val.dtype,
        output_dtype=out_tens.dtype,
        result_tensor=out_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=n_primary + n_const,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1633
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize a rescale of ``val`` to ``out_dtype`` with random parameters.

    Random input/output zero points and random scale values are generated.
    There is one scale per entry of the last dimension of ``val`` when
    ``per_channel`` is set, otherwise a single scale. Each scale is converted
    to a multiplier/shift pair by ``TosaQuantGen.computeMultiplierAndShift``
    and stored in the RescaleAttribute.

    Args:
        op: operator description dict (``"op"`` and ``"operands"`` are read).
        val: input tensor.
        out_dtype: output DType.
        scale32: when true the scale is capped at 2^31 - 1, otherwise at
            2^15 - 1; also forwarded to computeMultiplierAndShift.
        double_round: forwarded to the attribute and the validators.
        per_channel: select per-channel (last dimension) or single scaling.
        validator_fcns: ERROR_IF validator functions.
        error_name: the ErrorIf under test, or None.

    Returns:
        The result tensor, or None when ``evValidateErrorIfs`` reports failure.

    Side effects:
        For ``scale32`` tests with no ``error_name``, the placeholder ``.npy``
        file of ``val`` is rewritten if any value had to be clamped into the
        apply_scale_32 range.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # The type width is bumped by one bit in every branch where the zero point
    # may be non-zero (presumably headroom for the zero-point offset).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Shift with a Python int: under NumPy 2 (NEP 50) `1 << np.int32(n)`
        # stays int32 and overflows for large shifts. The bounds are stored in
        # int64 arrays.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1788
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor used by the cond_if builders.

    By default this is a rank-0 BOOL constant holding ``cond``. For
    ``ErrorIf.CondIfCondNotMatchingBool`` the dtype comes from
    ``get_wrong_output_type`` instead. For ``ErrorIf.CondIfCondShapeNotSizeOne``
    the shape is randomly ``[2]`` or ``[1, 2]``.
    """
    wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL) if wrong_type else DType.BOOL
    )

    # Must be of size 1 (rank 0) unless a bad shape is being tested.
    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1805
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a cond_if whose THEN/ELSE blocks each emit an INT32 constant.

    Only the shape of ``then_tens`` and the ``cond`` value are used;
    ``else_tens`` is ignored. For the output-list mismatch ERROR_IF tests,
    one block emits a constant of a randomly perturbed shape.

    Returns:
        The INT32 result tensor, or None when ``evValidateErrorIfs`` reports
        failure.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    # Create an incorrect output shape for error_if tests
    if error_name in (then_mismatch, else_mismatch):
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                incorrect_shape[idx] += self.rng.choice([-3, -2, 2, 3])
            else:
                incorrect_shape[idx] += self.rng.choice([1, 2, 4])
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The result tensor is based on either of the block outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], cond_attr)

    # Build each block with its constant output (the incorrect one if requested)
    for block_name, block_error, block_arr in (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == block_error:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if checks_ok else None
1874
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Serialize a cond_if whose THEN/ELSE blocks apply a binary op to a and b.

    FP32/BF16/FP16/INT32 inputs use ADD (THEN) and SUB (ELSE). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (THEN) and LOGICAL_LEFT_SHIFT (ELSE). For
    the input/output-list mismatch ERROR_IF tests, the selected block gets a
    randomly perturbed input or output shape.

    Returns:
        The result tensor, or None when ``evValidateErrorIfs`` reports failure.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable: ``op`` (the operator description dict)
    # must survive the loop because the ERROR_IF validation below needs it.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1957
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Serialize a while_loop over an INT32 counter, accumulator and ``a``.

    The loop carries (iter, a, acc), where iter starts at ``iter_val`` and acc
    starts at zeros. The COND block computes ``iter > 0``. The BODY block
    computes ``acc + a`` and ``iter - 1``. ERROR_IF variants perturb the
    shapes or types of particular block inputs/outputs.

    Returns:
        The accumulator output tensor, or None when ``evValidateErrorIfs``
        reports failure.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"
    loop_attr = ts.TosaSerializerAttribute()
    loop_attr.WhileLoopAttribute(cond_block, body_block)

    deltas = [-3, -2, 2, 3]

    def jitter_in_place(shape):
        # Add a random non-zero offset to every dimension of ``shape``.
        for dim_idx in range(len(shape)):
            shape[dim_idx] += self.rng.choice(deltas)

    # Accumulator tensor, initialised to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out_shape = acc.shape
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = deepcopy(acc)
        jitter_in_place(incorrect_acc.shape)
        acc_out_shape = incorrect_acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        loop_attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        jitter_in_place(incorrect_iter.shape)
        if len(incorrect_iter.shape) == 0:
            # A rank-0 shape has nothing to offset, so add a bogus dimension.
            incorrect_iter.shape.append(self.rng.choice(deltas))

        incorrect_acc = deepcopy(acc)
        jitter_in_place(incorrect_acc.shape)

    def add_block_inputs(use_incorrect):
        # Register (iter, a, acc), or the perturbed copies, as block inputs.
        if use_incorrect:
            block_inputs = (incorrect_iter, a, incorrect_acc)
        else:
            block_inputs = (iter_tens, a, acc)
        for tens in block_inputs:
            self.ser.addInputTensor(tens)

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)
    add_block_inputs(error_name == ErrorIf.InputListCondGraphMismatch)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    add_block_inputs(error_name == ErrorIf.InputListBodyGraphInputMismatch)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_like, acc_like = incorrect_iter, incorrect_acc
    else:
        iter_like, acc_like = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_like.shape, iter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for body_out in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(body_out)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    return acc_out if checks_ok else None
2078
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests.

    Args:
        op: TOSA_OP_LIST entry for the operator; its "rank" (min, max)
            tuple and its "types" list are read.
        shapeFilter: list of requested shapes; a false value (None or an
            empty list) means no shape restriction and is treated as [None].
        rankFilter: iterable of requested ranks, or None for no restriction.
        dtypeFilter: iterable of requested dtypes, or None for no restriction.
        testType: "positive" or "negative".
        validator: ERROR_IF validator function, only used for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" (a list or range) and
        "dtypeFilter" keys. None is returned for a negative test without a
        validator, or for a testType other than "positive"/"negative".
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Bound the default behaviour by the default range or by the operator,
        # whichever is the smaller range of ranks - but never generate a rank
        # outside the operator's (min, max) limits.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
    else:
        # Explicit shapes were requested: allow the operator's full rank range
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep the operator dtypes that were requested; for a dtype list entry
        # (e.g. [Input Type 1, Input Type 2, Accumulator Type]) the first
        # element is matched as well
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Use the parameters required by the ERROR_IF check where it sets them
        reqRank = error_arguments["rank"]
        reqDtype = error_arguments["dtype"]
        reqShape = error_arguments["shape"]
        return {
            # Otherwise reduce the number of shapes to keep test numbers small
            "shapeFilter": reqShape if reqShape is not None else shapeFilter[:2],
            "rankFilter": reqRank if reqRank is not None else cleanRankFilter,
            "dtypeFilter": reqDtype if reqDtype is not None else cleanDtypeFilter,
        }

    return None
2159
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for an operator.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to.
            None (the default) or [None] means no shape restriction.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples. For positive tests, any test flagged by at least one of the
        op's "invalid_test_validators" is left out.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    # Avoid a shared mutable default; [None] means "no shape restriction"
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    # Unpacking also checks that build_fcn has all four entries
    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        else:
                            testStr = "{}_ERRORIF_{}_{}_{}".format(
                                opName, error_name, shapeStr, typeStr
                            )
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped if ANY validator flags it
        # (previously only the last validator's verdict was kept).
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2266
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one operator test and serialize it if the build says it is valid.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        testStr: test name, passed to createSerializer.
        dtype_or_dtypeList: a single dtype used for every operand, or a list
            with one dtype per operand.
        error_name: ERROR_IF error being tested, or None.
        shapeList: list of operand shapes.
        testArgs: extra arguments passed through to the build function.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT gets one dtype per shape rather than per declared operand
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is generated before the tensor values so the
    # random number generator is consumed in the same order as before
    qgen = op.get("qgen")
    qinfo = qgen(self, op, dtype_or_dtypeList, error_name) if qgen is not None else None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            # Without validators, qinfo (if any) is passed positionally last
            extraArgs = [] if qinfo is None else [qinfo]
            resultName = build_fcn(self, op, *tens, *testArgs, *extraArgs)
        else:
            buildKwargs = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                buildKwargs["qinfo"] = qinfo
            resultName = build_fcn(self, op, *tens, *testArgs, **buildKwargs)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002357
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST in place.

    Each "<base>_TEMPLATE" entry is shallow-copied once per kernel size into
    "<base>_<k0>x<k1>" (or "<k0>x<k1>x<k2>" for conv3d), with "filter" set
    to the kernel and "template" set to False. Afterwards every entry whose
    "template" value is truthy is deleted. If "conv2d_TEMPLATE" is no longer
    present, the expansion is assumed to have happened already and nothing
    is changed.
    """
    opList = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in opList:
        # Already created these lists (can occur when class is initialized more than once)
        return

    def addKernelVariant(baseName, kernel):
        # Specialise a copy of the template for a single kernel size
        variant = opList["{}_TEMPLATE".format(baseName)].copy()
        variant["filter"] = kernel
        variant["template"] = False
        kernelStr = "x".join(str(dim) for dim in kernel)
        opList["{}_{}".format(baseName, kernelStr)] = variant

    # 2D convolution kernels; the three 2D ops are added together per kernel
    for kernel in [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]:
        for baseName in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addKernelVariant(baseName, kernel)

    # 3D convolution kernels
    for kernel in [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]:
        addKernelVariant("conv3d", kernel)

    # Templates are removed only once all dynamic ops exist. The names are
    # collected first because a dict must not change size while iterated.
    templateNames = [
        name for name, entry in opList.items() if entry.get("template")
    ]
    for name in templateNames:
        del opList[name]
2408
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every TOSA_OP_LIST entry must have an "operands" pair, a four-entry
    "build_fcn" tuple, a "types" field and an "op" field. A missing "rank"
    is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op found with a missing or malformed
            required field.
    """
    for opName, opInfo in self.TOSA_OP_LIST.items():
        # Required fields - unpacking also checks the tuple lengths
        try:
            _, _ = opInfo["operands"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(opName)
            ) from err

        try:
            _, _, _, _ = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError) as err:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    opName
                )
            ) from err

        if "types" not in opInfo:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(opName)
            )

        if "op" not in opInfo:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(opName)
            )

        # Put in default rank range, if missing
        opInfo.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002450
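# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE, i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator(), TensorGen function,
#    TensorValuesGen function, ArgGen function or None)
# 'types': array of datatypes to be tested
# Optional fields:
# 'qgen': quantization info generator function
# 'error_if_validators': ERROR_IF validator functions used for negative tests
# 'invalid_test_validators': functions that flag positive tests to be dropped
# 'template': True for templated ops expanded by createDynamicOpLists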
# Floating-point dtypes
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer dtypes (INT4 is deliberately excluded)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer then floating-point dtypes (INT4 is deliberately excluded)
TYPE_INT_FP = [
    DType.INT8, DType.INT16, DType.INT32,
    DType.FP16, DType.BF16, DType.FP32,
]

# Boolean dtype
TYPE_BOOL = [DType.BOOL]
# Floating-point dtypes followed by INT32
TYPE_FI32 = [DType.FP32, DType.FP16, DType.BF16, DType.INT32]
# Floating-point, integer and boolean dtypes
TYPE_FIB = [
    DType.FP16, DType.BF16, DType.FP32,
    DType.INT8, DType.INT16, DType.INT32,
    DType.BOOL,
]
# FP32 and INT16
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integer dtypes followed by the floating-point dtypes
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# Convolution dtype combinations, each [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Inclusive (min, max) rank range given to ops that do not specify "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002502
2503 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002504 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002505 "argmax": {
2506 "op": Op.ARGMAX,
2507 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002508 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002509 "build_fcn": (
2510 build_argmax,
2511 TosaTensorGen.tgBasic,
2512 TosaTensorValuesGen.tvgDefault,
2513 TosaArgGen.agAxis,
2514 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002515 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002516 "error_if_validators": (
2517 TosaErrorValidator.evAxisSmallerZero,
2518 TosaErrorValidator.evAxisLargerRank,
2519 TosaErrorValidator.evArgmaxOutputRankMismatch,
2520 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2521 TosaErrorValidator.evWrongRank,
2522 TosaErrorValidator.evWrongInputType,
2523 TosaErrorValidator.evWrongOutputType,
2524 TosaErrorValidator.evWrongInputList,
2525 TosaErrorValidator.evWrongOutputList,
2526 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002527 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002528 "avg_pool2d": {
2529 "op": Op.AVG_POOL2D,
2530 "operands": (1, 0),
2531 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002532 "build_fcn": (
2533 build_pool2d,
2534 TosaTensorGen.tgNHWC,
2535 TosaTensorValuesGen.tvgDefault,
2536 TosaArgGen.agPooling,
2537 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002538 "qgen": TosaQuantGen.qgUnary,
2539 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002540 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002541 "error_if_validators": (
2542 TosaErrorValidator.evKernelSmallerOne,
2543 TosaErrorValidator.evStrideSmallerOne,
2544 TosaErrorValidator.evPadSmallerZero,
2545 TosaErrorValidator.evWrongRank,
2546 TosaErrorValidator.evWrongInputType,
2547 TosaErrorValidator.evWrongOutputType,
2548 TosaErrorValidator.evWrongInputList,
2549 TosaErrorValidator.evWrongOutputList,
2550 TosaErrorValidator.evInputZeroPointNotZero,
2551 TosaErrorValidator.evOutputZeroPointNotZero,
2552 TosaErrorValidator.evPadLargerEqualKernel,
2553 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002554 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002555 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002556 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002557 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002558 "conv2d_TEMPLATE": {
2559 "op": Op.CONV2D,
2560 "operands": (1, 2),
2561 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002562 "build_fcn": (
2563 build_conv2d,
2564 TosaTensorGen.tgConv2D,
2565 TosaTensorValuesGen.tvgDefault,
2566 TosaArgGen.agConv,
2567 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002568 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002569 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002570 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2571 "error_if_validators": (
2572 TosaErrorValidator.evWrongInputType,
2573 TosaErrorValidator.evWrongOutputType,
2574 TosaErrorValidator.evWrongInputList,
2575 TosaErrorValidator.evWrongOutputList,
2576 TosaErrorValidator.evInputZeroPointNotZero,
2577 TosaErrorValidator.evWeightZeroPointNotZero,
2578 TosaErrorValidator.evPadSmallerZero,
2579 TosaErrorValidator.evStrideSmallerOne,
2580 TosaErrorValidator.evDilationSmallerOne,
2581 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002582 TosaErrorValidator.evConvOutputShapeMismatch,
2583 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002584 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002585 "template": True,
2586 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002587 # Templated operator. Filled in by createDynamicOpLists
2588 "conv3d_TEMPLATE": {
2589 "op": Op.CONV3D,
2590 "operands": (1, 2),
2591 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002592 "build_fcn": (
2593 build_conv3d,
2594 TosaTensorGen.tgConv3D,
2595 TosaTensorValuesGen.tvgDefault,
2596 TosaArgGen.agConv,
2597 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002598 "qgen": TosaQuantGen.qgConv,
2599 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002600 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2601 "error_if_validators": (
2602 TosaErrorValidator.evWrongInputType,
2603 TosaErrorValidator.evWrongOutputType,
2604 TosaErrorValidator.evWrongInputList,
2605 TosaErrorValidator.evWrongOutputList,
2606 TosaErrorValidator.evInputZeroPointNotZero,
2607 TosaErrorValidator.evWeightZeroPointNotZero,
2608 TosaErrorValidator.evPadSmallerZero,
2609 TosaErrorValidator.evStrideSmallerOne,
2610 TosaErrorValidator.evDilationSmallerOne,
2611 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002612 TosaErrorValidator.evConvOutputShapeMismatch,
2613 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002614 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002615 "template": True,
2616 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002617 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002618 "depthwise_conv2d_TEMPLATE": {
2619 "op": Op.DEPTHWISE_CONV2D,
2620 "operands": (1, 2),
2621 "filter": [1, 1],
2622 "rank": (4, 4),
2623 "build_fcn": (
2624 build_depthwise_conv2d,
2625 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002626 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002627 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002628 ),
2629 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002630 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002631 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2632 "error_if_validators": (
2633 TosaErrorValidator.evWrongInputType,
2634 TosaErrorValidator.evWrongOutputType,
2635 TosaErrorValidator.evWrongInputList,
2636 TosaErrorValidator.evWrongOutputList,
2637 TosaErrorValidator.evInputZeroPointNotZero,
2638 TosaErrorValidator.evWeightZeroPointNotZero,
2639 TosaErrorValidator.evPadSmallerZero,
2640 TosaErrorValidator.evStrideSmallerOne,
2641 TosaErrorValidator.evDilationSmallerOne,
2642 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002643 TosaErrorValidator.evConvOutputShapeMismatch,
2644 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002645 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002646 "template": True,
2647 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002648 "fully_connected": {
2649 "op": Op.FULLY_CONNECTED,
2650 "operands": (1, 2),
2651 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002652 "build_fcn": (
2653 build_fully_connected,
2654 TosaTensorGen.tgFullyConnected,
2655 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002656 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002657 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002658 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002659 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002660 "error_if_validators": (
2661 TosaErrorValidator.evInputZeroPointNotZero,
2662 TosaErrorValidator.evWeightZeroPointNotZero,
2663 TosaErrorValidator.evWrongRank,
2664 TosaErrorValidator.evWrongInputType,
2665 TosaErrorValidator.evWrongOutputType,
2666 TosaErrorValidator.evWrongInputList,
2667 TosaErrorValidator.evWrongOutputList,
2668 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002669 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002670 "matmul": {
2671 "op": Op.MATMUL,
2672 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002673 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002674 "build_fcn": (
2675 build_matmul,
2676 TosaTensorGen.tgMatmul,
2677 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002678 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002679 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002680 "qgen": TosaQuantGen.qgMatmul,
2681 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002682 "error_if_validators": (
2683 TosaErrorValidator.evInputZeroPointNotZero,
2684 TosaErrorValidator.evWrongRank,
2685 TosaErrorValidator.evWrongInputType,
2686 TosaErrorValidator.evWrongOutputType,
2687 TosaErrorValidator.evWrongInputList,
2688 TosaErrorValidator.evWrongOutputList,
2689 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002690 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002691 "max_pool2d": {
2692 "op": Op.MAX_POOL2D,
2693 "operands": (1, 0),
2694 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002695 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002696 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002697 TosaTensorGen.tgNHWC,
2698 TosaTensorValuesGen.tvgDefault,
2699 TosaArgGen.agPooling,
2700 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002701 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002702 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002703 "error_if_validators": (
2704 TosaErrorValidator.evKernelSmallerOne,
2705 TosaErrorValidator.evStrideSmallerOne,
2706 TosaErrorValidator.evPadSmallerZero,
2707 TosaErrorValidator.evWrongRank,
2708 TosaErrorValidator.evWrongInputType,
2709 TosaErrorValidator.evWrongOutputType,
2710 TosaErrorValidator.evWrongInputList,
2711 TosaErrorValidator.evWrongOutputList,
2712 TosaErrorValidator.evPadLargerEqualKernel,
2713 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002714 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002715 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002716 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002717 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002718 "transpose_conv2d_TEMPLATE": {
2719 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002720 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002721 "rank": (4, 4),
2722 "build_fcn": (
2723 build_transpose_conv2d,
2724 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002725 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002726 TosaArgGen.agTransposeConv2D,
2727 ),
2728 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002729 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002730 "invalid_test_validators": (
2731 TosaInvalidValidator.ivHeightWidthInvalid,
2732 TosaInvalidValidator.ivNonPositiveOutputShape,
2733 ),
2734 "error_if_validators": (
2735 TosaErrorValidator.evWrongInputType,
2736 TosaErrorValidator.evWrongOutputType,
2737 TosaErrorValidator.evWrongInputList,
2738 TosaErrorValidator.evWrongOutputList,
2739 TosaErrorValidator.evInputZeroPointNotZero,
2740 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002741 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002742 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002743 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002744 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002745 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002746 "template": True,
2747 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002748 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002749 "clamp": {
2750 "op": Op.CLAMP,
2751 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002752 "build_fcn": (
2753 build_clamp,
2754 TosaTensorGen.tgBasic,
2755 TosaTensorValuesGen.tvgDefault,
2756 None,
2757 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002758 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002759 "error_if_validators": (
2760 TosaErrorValidator.evMaxSmallerMin,
2761 TosaErrorValidator.evWrongInputType,
2762 TosaErrorValidator.evWrongOutputType,
2763 TosaErrorValidator.evWrongInputList,
2764 TosaErrorValidator.evWrongOutputList,
2765 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002766 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002767 "sigmoid": {
2768 "op": Op.SIGMOID,
2769 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002770 "build_fcn": (
2771 build_sigmoid,
2772 TosaTensorGen.tgBasic,
2773 TosaTensorValuesGen.tvgDefault,
2774 None,
2775 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002776 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002777 "error_if_validators": (
2778 TosaErrorValidator.evWrongInputType,
2779 TosaErrorValidator.evWrongOutputType,
2780 TosaErrorValidator.evWrongInputList,
2781 TosaErrorValidator.evWrongOutputList,
2782 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002783 },
2784 "tanh": {
2785 "op": Op.TANH,
2786 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002787 "build_fcn": (
2788 build_tanh,
2789 TosaTensorGen.tgBasic,
2790 TosaTensorValuesGen.tvgDefault,
2791 None,
2792 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002793 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002794 "error_if_validators": (
2795 TosaErrorValidator.evWrongInputType,
2796 TosaErrorValidator.evWrongOutputType,
2797 TosaErrorValidator.evWrongInputList,
2798 TosaErrorValidator.evWrongOutputList,
2799 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002800 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002801 # Elementwise Binary Operators
2802 "add": {
2803 "op": Op.ADD,
2804 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002805 "build_fcn": (
2806 build_binary_broadcast,
2807 TosaTensorGen.tgBroadcastFuzz,
2808 TosaTensorValuesGen.tvgAddSub,
2809 None,
2810 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002811 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002812 "error_if_validators": (
2813 TosaErrorValidator.evRankMismatch,
2814 TosaErrorValidator.evWrongInputType,
2815 TosaErrorValidator.evWrongOutputType,
2816 TosaErrorValidator.evWrongInputList,
2817 TosaErrorValidator.evWrongOutputList,
2818 TosaErrorValidator.evDimensionMismatch,
2819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002821 "arithmetic_right_shift": {
2822 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2823 "operands": (2, 0),
2824 "build_fcn": (
2825 build_arithmetic_right_shift,
2826 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002827 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002828 TosaArgGen.agArithmeticRightShift,
2829 ),
2830 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002831 "error_if_validators": (
2832 TosaErrorValidator.evRankMismatch,
2833 TosaErrorValidator.evWrongInputType,
2834 TosaErrorValidator.evWrongOutputType,
2835 TosaErrorValidator.evWrongInputList,
2836 TosaErrorValidator.evWrongOutputList,
2837 TosaErrorValidator.evDimensionMismatch,
2838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002839 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002840 "bitwise_and": {
2841 "op": Op.BITWISE_AND,
2842 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002843 "build_fcn": (
2844 build_binary_broadcast,
2845 TosaTensorGen.tgBroadcastFuzz,
2846 TosaTensorValuesGen.tvgDefault,
2847 None,
2848 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002849 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002850 "error_if_validators": (
2851 TosaErrorValidator.evRankMismatch,
2852 TosaErrorValidator.evWrongInputType,
2853 TosaErrorValidator.evWrongOutputType,
2854 TosaErrorValidator.evWrongInputList,
2855 TosaErrorValidator.evWrongOutputList,
2856 TosaErrorValidator.evDimensionMismatch,
2857 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002858 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002859 "bitwise_or": {
2860 "op": Op.BITWISE_OR,
2861 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002862 "build_fcn": (
2863 build_binary_broadcast,
2864 TosaTensorGen.tgBroadcastFuzz,
2865 TosaTensorValuesGen.tvgDefault,
2866 None,
2867 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002868 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002869 "error_if_validators": (
2870 TosaErrorValidator.evRankMismatch,
2871 TosaErrorValidator.evWrongInputType,
2872 TosaErrorValidator.evWrongOutputType,
2873 TosaErrorValidator.evWrongInputList,
2874 TosaErrorValidator.evWrongOutputList,
2875 TosaErrorValidator.evDimensionMismatch,
2876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002877 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002878 "bitwise_xor": {
2879 "op": Op.BITWISE_XOR,
2880 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002881 "build_fcn": (
2882 build_binary_broadcast,
2883 TosaTensorGen.tgBroadcastFuzz,
2884 TosaTensorValuesGen.tvgDefault,
2885 None,
2886 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002887 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002888 "error_if_validators": (
2889 TosaErrorValidator.evRankMismatch,
2890 TosaErrorValidator.evWrongInputType,
2891 TosaErrorValidator.evWrongOutputType,
2892 TosaErrorValidator.evWrongInputList,
2893 TosaErrorValidator.evWrongOutputList,
2894 TosaErrorValidator.evDimensionMismatch,
2895 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002896 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002897 "intdiv": {
2898 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002899 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002900 "build_fcn": (
2901 build_binary_broadcast,
2902 TosaTensorGen.tgBroadcastFuzz,
2903 TosaTensorValuesGen.tvgIntDiv,
2904 None,
2905 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002906 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002907 "error_if_validators": (
2908 TosaErrorValidator.evRankMismatch,
2909 TosaErrorValidator.evWrongInputType,
2910 TosaErrorValidator.evWrongOutputType,
2911 TosaErrorValidator.evWrongInputList,
2912 TosaErrorValidator.evWrongOutputList,
2913 TosaErrorValidator.evDimensionMismatch,
2914 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002915 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002916 "logical_and": {
2917 "op": Op.LOGICAL_AND,
2918 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002919 "build_fcn": (
2920 build_binary_broadcast,
2921 TosaTensorGen.tgBroadcastFuzz,
2922 TosaTensorValuesGen.tvgDefault,
2923 None,
2924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002925 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002926 "error_if_validators": (
2927 TosaErrorValidator.evRankMismatch,
2928 TosaErrorValidator.evWrongInputType,
2929 TosaErrorValidator.evWrongOutputType,
2930 TosaErrorValidator.evWrongInputList,
2931 TosaErrorValidator.evWrongOutputList,
2932 TosaErrorValidator.evDimensionMismatch,
2933 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002934 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002935 "logical_left_shift": {
2936 "op": Op.LOGICAL_LEFT_SHIFT,
2937 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002938 "build_fcn": (
2939 build_binary_broadcast,
2940 TosaTensorGen.tgBroadcastFuzz,
2941 TosaTensorValuesGen.tvgLogicalShift,
2942 None,
2943 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002944 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002945 "error_if_validators": (
2946 TosaErrorValidator.evRankMismatch,
2947 TosaErrorValidator.evWrongInputType,
2948 TosaErrorValidator.evWrongOutputType,
2949 TosaErrorValidator.evWrongInputList,
2950 TosaErrorValidator.evWrongOutputList,
2951 TosaErrorValidator.evDimensionMismatch,
2952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002953 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002954 "logical_right_shift": {
2955 "op": Op.LOGICAL_RIGHT_SHIFT,
2956 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002957 "build_fcn": (
2958 build_binary_broadcast,
2959 TosaTensorGen.tgBroadcastFuzz,
2960 TosaTensorValuesGen.tvgLogicalShift,
2961 None,
2962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002963 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002964 "error_if_validators": (
2965 TosaErrorValidator.evRankMismatch,
2966 TosaErrorValidator.evWrongInputType,
2967 TosaErrorValidator.evWrongOutputType,
2968 TosaErrorValidator.evWrongInputList,
2969 TosaErrorValidator.evWrongOutputList,
2970 TosaErrorValidator.evDimensionMismatch,
2971 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002972 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002973 "logical_or": {
2974 "op": Op.LOGICAL_OR,
2975 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002976 "build_fcn": (
2977 build_binary_broadcast,
2978 TosaTensorGen.tgBroadcastFuzz,
2979 TosaTensorValuesGen.tvgDefault,
2980 None,
2981 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002982 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002983 "error_if_validators": (
2984 TosaErrorValidator.evRankMismatch,
2985 TosaErrorValidator.evWrongInputType,
2986 TosaErrorValidator.evWrongOutputType,
2987 TosaErrorValidator.evWrongInputList,
2988 TosaErrorValidator.evWrongOutputList,
2989 TosaErrorValidator.evDimensionMismatch,
2990 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002991 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002992 "logical_xor": {
2993 "op": Op.LOGICAL_XOR,
2994 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002995 "build_fcn": (
2996 build_binary_broadcast,
2997 TosaTensorGen.tgBroadcastFuzz,
2998 TosaTensorValuesGen.tvgDefault,
2999 None,
3000 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003001 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003002 "error_if_validators": (
3003 TosaErrorValidator.evRankMismatch,
3004 TosaErrorValidator.evWrongInputType,
3005 TosaErrorValidator.evWrongOutputType,
3006 TosaErrorValidator.evWrongInputList,
3007 TosaErrorValidator.evWrongOutputList,
3008 TosaErrorValidator.evDimensionMismatch,
3009 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003011 "maximum": {
3012 "op": Op.MAXIMUM,
3013 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003014 "build_fcn": (
3015 build_binary_broadcast,
3016 TosaTensorGen.tgBroadcastFuzz,
3017 TosaTensorValuesGen.tvgDefault,
3018 None,
3019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003020 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003021 "error_if_validators": (
3022 TosaErrorValidator.evRankMismatch,
3023 TosaErrorValidator.evWrongInputType,
3024 TosaErrorValidator.evWrongOutputType,
3025 TosaErrorValidator.evWrongInputList,
3026 TosaErrorValidator.evWrongOutputList,
3027 TosaErrorValidator.evDimensionMismatch,
3028 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003029 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003030 "minimum": {
3031 "op": Op.MINIMUM,
3032 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003033 "build_fcn": (
3034 build_binary_broadcast,
3035 TosaTensorGen.tgBroadcastFuzz,
3036 TosaTensorValuesGen.tvgDefault,
3037 None,
3038 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003039 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003040 "error_if_validators": (
3041 TosaErrorValidator.evRankMismatch,
3042 TosaErrorValidator.evWrongInputType,
3043 TosaErrorValidator.evWrongOutputType,
3044 TosaErrorValidator.evWrongInputList,
3045 TosaErrorValidator.evWrongOutputList,
3046 TosaErrorValidator.evDimensionMismatch,
3047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003048 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003049 "mul": {
3050 "op": Op.MUL,
3051 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003052 "build_fcn": (
3053 build_mul,
3054 TosaTensorGen.tgBroadcastFuzz,
3055 TosaTensorValuesGen.tvgMul,
3056 TosaArgGen.agMul,
3057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003058 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003059 "error_if_validators": (
3060 TosaErrorValidator.evWrongInputType,
3061 TosaErrorValidator.evWrongOutputType,
3062 TosaErrorValidator.evWrongInputList,
3063 TosaErrorValidator.evWrongOutputList,
3064 TosaErrorValidator.evRankMismatch,
3065 TosaErrorValidator.evDimensionMismatch,
3066 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003067 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003068 "pow": {
3069 "op": Op.POW,
3070 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003071 "build_fcn": (
3072 build_binary_broadcast,
3073 TosaTensorGen.tgBroadcastFuzz,
3074 TosaTensorValuesGen.tvgDefault,
3075 None,
3076 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003077 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003078 "error_if_validators": (
3079 TosaErrorValidator.evRankMismatch,
3080 TosaErrorValidator.evWrongInputType,
3081 TosaErrorValidator.evWrongOutputType,
3082 TosaErrorValidator.evWrongInputList,
3083 TosaErrorValidator.evWrongOutputList,
3084 TosaErrorValidator.evDimensionMismatch,
3085 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003086 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003087 "sub": {
3088 "op": Op.SUB,
3089 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003090 "build_fcn": (
3091 build_binary_broadcast,
3092 TosaTensorGen.tgBroadcastFuzz,
3093 TosaTensorValuesGen.tvgAddSub,
3094 None,
3095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003096 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003097 "error_if_validators": (
3098 TosaErrorValidator.evRankMismatch,
3099 TosaErrorValidator.evWrongInputType,
3100 TosaErrorValidator.evWrongOutputType,
3101 TosaErrorValidator.evWrongInputList,
3102 TosaErrorValidator.evWrongOutputList,
3103 TosaErrorValidator.evDimensionMismatch,
3104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003105 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003106 "table": {
3107 "op": Op.TABLE,
3108 # Use the automatic generation functions to create the input array
3109 # but create the table tensor in the build function, as it may be
3110 # a different type from the input
3111 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003112 "build_fcn": (
3113 build_table,
3114 TosaTensorGen.tgBasic,
3115 TosaTensorValuesGen.tvgDefault,
3116 TosaArgGen.agTable,
3117 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003118 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003119 "error_if_validators": (
3120 TosaErrorValidator.evWrongInputType,
3121 TosaErrorValidator.evWrongOutputType,
3122 TosaErrorValidator.evWrongInputList,
3123 TosaErrorValidator.evWrongOutputList,
3124 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003125 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 # Elementwise Unary operators
3127 "abs": {
3128 "op": Op.ABS,
3129 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003130 "build_fcn": (
3131 build_unary,
3132 TosaTensorGen.tgBasic,
3133 TosaTensorValuesGen.tvgDefault,
3134 None,
3135 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003136 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003137 "error_if_validators": (
3138 TosaErrorValidator.evWrongInputType,
3139 TosaErrorValidator.evWrongOutputType,
3140 TosaErrorValidator.evWrongInputList,
3141 TosaErrorValidator.evWrongOutputList,
3142 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003144 "bitwise_not": {
3145 "op": Op.BITWISE_NOT,
3146 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003147 "build_fcn": (
3148 build_unary,
3149 TosaTensorGen.tgBasic,
3150 TosaTensorValuesGen.tvgDefault,
3151 None,
3152 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003153 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003154 "error_if_validators": (
3155 TosaErrorValidator.evWrongInputType,
3156 TosaErrorValidator.evWrongOutputType,
3157 TosaErrorValidator.evWrongInputList,
3158 TosaErrorValidator.evWrongOutputList,
3159 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003160 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003161 "ceil": {
3162 "op": Op.CEIL,
3163 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003164 "build_fcn": (
3165 build_unary,
3166 TosaTensorGen.tgBasic,
3167 TosaTensorValuesGen.tvgDefault,
3168 None,
3169 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003170 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003171 "error_if_validators": (
3172 TosaErrorValidator.evWrongInputType,
3173 TosaErrorValidator.evWrongOutputType,
3174 TosaErrorValidator.evWrongInputList,
3175 TosaErrorValidator.evWrongOutputList,
3176 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003178 "clz": {
3179 "op": Op.CLZ,
3180 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003181 "build_fcn": (
3182 build_unary,
3183 TosaTensorGen.tgBasic,
3184 TosaTensorValuesGen.tvgDefault,
3185 None,
3186 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003187 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003188 "error_if_validators": (
3189 TosaErrorValidator.evWrongInputType,
3190 TosaErrorValidator.evWrongOutputType,
3191 TosaErrorValidator.evWrongInputList,
3192 TosaErrorValidator.evWrongOutputList,
3193 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003195 "exp": {
3196 "op": Op.EXP,
3197 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003198 "build_fcn": (
3199 build_unary,
3200 TosaTensorGen.tgBasic,
3201 TosaTensorValuesGen.tvgDefault,
3202 None,
3203 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003204 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003205 "error_if_validators": (
3206 TosaErrorValidator.evWrongInputType,
3207 TosaErrorValidator.evWrongOutputType,
3208 TosaErrorValidator.evWrongInputList,
3209 TosaErrorValidator.evWrongOutputList,
3210 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 "floor": {
3213 "op": Op.FLOOR,
3214 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003215 "build_fcn": (
3216 build_unary,
3217 TosaTensorGen.tgBasic,
3218 TosaTensorValuesGen.tvgDefault,
3219 None,
3220 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003221 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003222 "error_if_validators": (
3223 TosaErrorValidator.evWrongInputType,
3224 TosaErrorValidator.evWrongOutputType,
3225 TosaErrorValidator.evWrongInputList,
3226 TosaErrorValidator.evWrongOutputList,
3227 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003228 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003229 "log": {
3230 "op": Op.LOG,
3231 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003232 "build_fcn": (
3233 build_unary,
3234 TosaTensorGen.tgBasic,
3235 TosaTensorValuesGen.tvgDefault,
3236 None,
3237 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003238 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003239 "error_if_validators": (
3240 TosaErrorValidator.evWrongInputType,
3241 TosaErrorValidator.evWrongOutputType,
3242 TosaErrorValidator.evWrongInputList,
3243 TosaErrorValidator.evWrongOutputList,
3244 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003245 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003246 "logical_not": {
3247 "op": Op.LOGICAL_NOT,
3248 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003249 "build_fcn": (
3250 build_unary,
3251 TosaTensorGen.tgBasic,
3252 TosaTensorValuesGen.tvgDefault,
3253 None,
3254 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003255 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003256 "error_if_validators": (
3257 TosaErrorValidator.evWrongInputType,
3258 TosaErrorValidator.evWrongOutputType,
3259 TosaErrorValidator.evWrongInputList,
3260 TosaErrorValidator.evWrongOutputList,
3261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003263 "negate": {
3264 "op": Op.NEGATE,
3265 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003266 "build_fcn": (
3267 build_unary,
3268 TosaTensorGen.tgBasic,
3269 TosaTensorValuesGen.tvgNegate,
3270 None,
3271 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003272 "qgen": TosaQuantGen.qgUnary,
3273 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003274 "error_if_validators": (
3275 TosaErrorValidator.evInputZeroPointNotZero,
3276 TosaErrorValidator.evOutputZeroPointNotZero,
3277 TosaErrorValidator.evWrongInputType,
3278 TosaErrorValidator.evWrongOutputType,
3279 TosaErrorValidator.evWrongInputList,
3280 TosaErrorValidator.evWrongOutputList,
3281 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003283 "reciprocal": {
3284 "op": Op.RECIPROCAL,
3285 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003286 "build_fcn": (
3287 build_unary,
3288 TosaTensorGen.tgBasic,
3289 TosaTensorValuesGen.tvgDefault,
3290 None,
3291 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003292 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003293 "error_if_validators": (
3294 TosaErrorValidator.evWrongInputType,
3295 TosaErrorValidator.evWrongOutputType,
3296 TosaErrorValidator.evWrongInputList,
3297 TosaErrorValidator.evWrongOutputList,
3298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003300 "rsqrt": {
3301 "op": Op.RSQRT,
3302 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303 "build_fcn": (
3304 build_unary,
3305 TosaTensorGen.tgBasic,
3306 TosaTensorValuesGen.tvgDefault,
3307 None,
3308 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003310 "error_if_validators": (
3311 TosaErrorValidator.evWrongInputType,
3312 TosaErrorValidator.evWrongOutputType,
3313 TosaErrorValidator.evWrongInputList,
3314 TosaErrorValidator.evWrongOutputList,
3315 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003317 # Elementwise Ternary operators
3318 "select": {
3319 "op": Op.SELECT,
3320 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003321 "build_fcn": (
3322 build_select,
3323 TosaTensorGen.tgBroadcastFuzz,
3324 TosaTensorValuesGen.tvgSelect,
3325 None,
3326 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003327 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003328 "error_if_validators": (
3329 TosaErrorValidator.evRankMismatch,
3330 TosaErrorValidator.evWrongInputType,
3331 TosaErrorValidator.evWrongOutputType,
3332 TosaErrorValidator.evWrongInputList,
3333 TosaErrorValidator.evWrongOutputList,
3334 TosaErrorValidator.evDimensionMismatch,
3335 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 # Comparison operators
3338 "equal": {
3339 "op": Op.EQUAL,
3340 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003341 "build_fcn": (
3342 build_comparison,
3343 TosaTensorGen.tgBroadcastFuzz,
3344 TosaTensorValuesGen.tvgEqual,
3345 None,
3346 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003348 "error_if_validators": (
3349 TosaErrorValidator.evRankMismatch,
3350 TosaErrorValidator.evWrongInputType,
3351 TosaErrorValidator.evWrongOutputType,
3352 TosaErrorValidator.evWrongInputList,
3353 TosaErrorValidator.evWrongOutputList,
3354 TosaErrorValidator.evDimensionMismatch,
3355 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003356 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003357 "greater_equal": {
3358 "op": Op.GREATER_EQUAL,
3359 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003360 "build_fcn": (
3361 build_comparison,
3362 TosaTensorGen.tgBroadcastFuzz,
3363 TosaTensorValuesGen.tvgDefault,
3364 None,
3365 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003366 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003367 "error_if_validators": (
3368 TosaErrorValidator.evRankMismatch,
3369 TosaErrorValidator.evWrongInputType,
3370 TosaErrorValidator.evWrongOutputType,
3371 TosaErrorValidator.evWrongInputList,
3372 TosaErrorValidator.evWrongOutputList,
3373 TosaErrorValidator.evDimensionMismatch,
3374 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003376 "greater": {
3377 "op": Op.GREATER,
3378 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003379 "build_fcn": (
3380 build_comparison,
3381 TosaTensorGen.tgBroadcastFuzz,
3382 TosaTensorValuesGen.tvgDefault,
3383 None,
3384 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003386 "error_if_validators": (
3387 TosaErrorValidator.evRankMismatch,
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongInputList,
3391 TosaErrorValidator.evWrongOutputList,
3392 TosaErrorValidator.evDimensionMismatch,
3393 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003395 # Reduction operators
3396 "reduce_all": {
3397 "op": Op.REDUCE_ALL,
3398 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003399 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003400 "build_fcn": (
3401 build_reduce,
3402 TosaTensorGen.tgBasic,
3403 TosaTensorValuesGen.tvgDefault,
3404 TosaArgGen.agAxis,
3405 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003406 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003407 "error_if_validators": (
3408 TosaErrorValidator.evAxisLargerRank,
3409 TosaErrorValidator.evAxisSmallerZero,
3410 TosaErrorValidator.evShapeOfAxisNotOne,
3411 TosaErrorValidator.evWrongInputType,
3412 TosaErrorValidator.evWrongOutputType,
3413 TosaErrorValidator.evWrongRank,
3414 TosaErrorValidator.evWrongInputList,
3415 TosaErrorValidator.evWrongOutputList,
3416 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003417 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 "reduce_any": {
3419 "op": Op.REDUCE_ANY,
3420 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003421 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 "build_fcn": (
3423 build_reduce,
3424 TosaTensorGen.tgBasic,
3425 TosaTensorValuesGen.tvgDefault,
3426 TosaArgGen.agAxis,
3427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 "error_if_validators": (
3430 TosaErrorValidator.evAxisLargerRank,
3431 TosaErrorValidator.evAxisSmallerZero,
3432 TosaErrorValidator.evShapeOfAxisNotOne,
3433 TosaErrorValidator.evWrongInputType,
3434 TosaErrorValidator.evWrongOutputType,
3435 TosaErrorValidator.evWrongRank,
3436 TosaErrorValidator.evWrongInputList,
3437 TosaErrorValidator.evWrongOutputList,
3438 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003439 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003440 "reduce_max": {
3441 "op": Op.REDUCE_MAX,
3442 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003443 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003444 "build_fcn": (
3445 build_reduce,
3446 TosaTensorGen.tgBasic,
3447 TosaTensorValuesGen.tvgDefault,
3448 TosaArgGen.agAxis,
3449 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003450 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003451 "error_if_validators": (
3452 TosaErrorValidator.evAxisLargerRank,
3453 TosaErrorValidator.evAxisSmallerZero,
3454 TosaErrorValidator.evShapeOfAxisNotOne,
3455 TosaErrorValidator.evWrongInputType,
3456 TosaErrorValidator.evWrongOutputType,
3457 TosaErrorValidator.evWrongRank,
3458 TosaErrorValidator.evWrongInputList,
3459 TosaErrorValidator.evWrongOutputList,
3460 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003461 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003463 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003464 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003465 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003466 "build_fcn": (
3467 build_reduce,
3468 TosaTensorGen.tgBasic,
3469 TosaTensorValuesGen.tvgDefault,
3470 TosaArgGen.agAxis,
3471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003473 "error_if_validators": (
3474 TosaErrorValidator.evAxisLargerRank,
3475 TosaErrorValidator.evAxisSmallerZero,
3476 TosaErrorValidator.evShapeOfAxisNotOne,
3477 TosaErrorValidator.evWrongInputType,
3478 TosaErrorValidator.evWrongOutputType,
3479 TosaErrorValidator.evWrongRank,
3480 TosaErrorValidator.evWrongInputList,
3481 TosaErrorValidator.evWrongOutputList,
3482 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "reduce_product": {
3485 "op": Op.REDUCE_PRODUCT,
3486 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003487 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003488 "build_fcn": (
3489 build_reduce,
3490 TosaTensorGen.tgBasic,
3491 TosaTensorValuesGen.tvgDefault,
3492 TosaArgGen.agAxis,
3493 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003494 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003495 "error_if_validators": (
3496 TosaErrorValidator.evAxisLargerRank,
3497 TosaErrorValidator.evAxisSmallerZero,
3498 TosaErrorValidator.evShapeOfAxisNotOne,
3499 TosaErrorValidator.evWrongInputType,
3500 TosaErrorValidator.evWrongOutputType,
3501 TosaErrorValidator.evWrongRank,
3502 TosaErrorValidator.evWrongInputList,
3503 TosaErrorValidator.evWrongOutputList,
3504 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 "reduce_sum": {
3507 "op": Op.REDUCE_SUM,
3508 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003509 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003510 "build_fcn": (
3511 build_reduce,
3512 TosaTensorGen.tgBasic,
3513 TosaTensorValuesGen.tvgReduceSum,
3514 TosaArgGen.agAxis,
3515 ),
James Ward24dbc422022-10-19 12:20:31 +01003516 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003517 "error_if_validators": (
3518 TosaErrorValidator.evAxisLargerRank,
3519 TosaErrorValidator.evAxisSmallerZero,
3520 TosaErrorValidator.evShapeOfAxisNotOne,
3521 TosaErrorValidator.evWrongInputType,
3522 TosaErrorValidator.evWrongOutputType,
3523 TosaErrorValidator.evWrongRank,
3524 TosaErrorValidator.evWrongInputList,
3525 TosaErrorValidator.evWrongOutputList,
3526 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003528 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003529 "concat": {
3530 "op": Op.CONCAT,
3531 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003532 "build_fcn": (
3533 build_concat,
3534 TosaTensorGen.tgConcat,
3535 TosaTensorValuesGen.tvgConcat,
3536 TosaArgGen.agAxis,
3537 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003538 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003539 "error_if_validators": (
3540 TosaErrorValidator.evAxisLargerRank,
3541 TosaErrorValidator.evAxisSmallerZero,
3542 TosaErrorValidator.evConcatInputRankMismatch,
3543 TosaErrorValidator.evConcatShapeSumMismatch,
3544 TosaErrorValidator.evConcatInputDimMismatch,
3545 TosaErrorValidator.evWrongInputType,
3546 TosaErrorValidator.evWrongOutputType,
3547 TosaErrorValidator.evWrongOutputList,
3548 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003549 },
3550 "pad": {
3551 "op": Op.PAD,
3552 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003553 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003554 "build_fcn": (
3555 build_pad,
3556 TosaTensorGen.tgBasic,
3557 TosaTensorValuesGen.tvgDefault,
3558 TosaArgGen.agPad,
3559 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003560 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003561 "error_if_validators": (
3562 TosaErrorValidator.evWrongInputType,
3563 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003564 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003565 TosaErrorValidator.evWrongOutputType,
3566 TosaErrorValidator.evWrongInputList,
3567 TosaErrorValidator.evWrongOutputList,
3568 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003569 },
3570 "reshape": {
3571 "op": Op.RESHAPE,
3572 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003573 "build_fcn": (
3574 build_reshape,
3575 TosaTensorGen.tgBasic,
3576 TosaTensorValuesGen.tvgDefault,
3577 TosaArgGen.agReshape,
3578 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003579 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003580 "error_if_validators": (
3581 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3582 TosaErrorValidator.evWrongInputType,
3583 TosaErrorValidator.evWrongOutputType,
3584 TosaErrorValidator.evWrongInputList,
3585 TosaErrorValidator.evWrongOutputList,
3586 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003587 },
3588 "reverse": {
3589 "op": Op.REVERSE,
3590 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003591 "build_fcn": (
3592 build_reverse,
3593 TosaTensorGen.tgBasic,
3594 TosaTensorValuesGen.tvgDefault,
3595 TosaArgGen.agAxis,
3596 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003597 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003598 "error_if_validators": (
3599 TosaErrorValidator.evAxisSmallerZero,
3600 TosaErrorValidator.evAxisLargerRank,
3601 TosaErrorValidator.evWrongInputType,
3602 TosaErrorValidator.evWrongOutputType,
3603 TosaErrorValidator.evWrongInputList,
3604 TosaErrorValidator.evWrongOutputList,
3605 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003606 },
3607 "slice": {
3608 "op": Op.SLICE,
3609 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003610 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003611 "build_fcn": (
3612 build_slice,
3613 TosaTensorGen.tgBasic,
3614 TosaTensorValuesGen.tvgDefault,
3615 TosaArgGen.agSlice,
3616 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003617 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003618 "error_if_validators": (
3619 TosaErrorValidator.evStartSmallerZero,
3620 TosaErrorValidator.evSizeSmallerEqualZero,
3621 TosaErrorValidator.evStartSizeOutsideBounds,
3622 TosaErrorValidator.evSizeOutputShapeMismatch,
3623 TosaErrorValidator.evInputSizeStartLengthMismatch,
3624 TosaErrorValidator.evWrongRank,
3625 TosaErrorValidator.evWrongInputType,
3626 TosaErrorValidator.evWrongOutputType,
3627 TosaErrorValidator.evWrongInputList,
3628 TosaErrorValidator.evWrongOutputList,
3629 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003630 },
3631 "tile": {
3632 "op": Op.TILE,
3633 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 "build_fcn": (
3635 build_tile,
3636 TosaTensorGen.tgBasic,
3637 TosaTensorValuesGen.tvgDefault,
3638 TosaArgGen.agTile,
3639 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003640 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003641 "error_if_validators": (
3642 TosaErrorValidator.evWrongInputType,
3643 TosaErrorValidator.evWrongOutputType,
3644 TosaErrorValidator.evWrongInputList,
3645 TosaErrorValidator.evWrongOutputList,
3646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 },
3648 "transpose": {
3649 "op": Op.TRANSPOSE,
3650 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003651 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003652 "build_fcn": (
3653 build_transpose,
3654 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003656 TosaArgGen.agTranspose,
3657 ),
3658 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003659 "error_if_validators": (
3660 TosaErrorValidator.evIndexOutsideBounds,
3661 TosaErrorValidator.evIndexUsedTwice,
3662 TosaErrorValidator.evWrongInputType,
3663 TosaErrorValidator.evWrongOutputType,
3664 TosaErrorValidator.evWrongInputList,
3665 TosaErrorValidator.evWrongOutputList,
3666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003668 # Data nodes
3669 "const": {
3670 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003671 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003672 "build_fcn": (
3673 build_const,
3674 TosaTensorGen.tgBasic,
3675 TosaTensorValuesGen.tvgDefault,
3676 None,
3677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "types": TYPE_FIB,
3679 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003680 "identity": {
3681 "op": Op.IDENTITY,
3682 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003683 "build_fcn": (
3684 build_unary,
3685 TosaTensorGen.tgBasic,
3686 TosaTensorValuesGen.tvgDefault,
3687 None,
3688 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003689 "types": TYPE_FIB,
3690 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003691 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003692 "gather": {
3693 "op": Op.GATHER,
3694 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3695 "operands": (1, 0),
3696 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003697 "build_fcn": (
3698 build_gather,
3699 TosaTensorGen.tgBasic,
3700 TosaTensorValuesGen.tvgDefault,
3701 None,
3702 ),
James Ward24dbc422022-10-19 12:20:31 +01003703 "types": (
3704 DType.INT8,
3705 DType.INT16,
3706 DType.INT32,
3707 DType.FP16,
3708 DType.BF16,
3709 DType.FP32,
3710 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003711 "error_if_validators": (
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
3716 TosaErrorValidator.evWrongRank,
3717 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003718 },
3719 "scatter": {
3720 "op": Op.SCATTER,
3721 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003722 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003723 "operands": (2, 0),
3724 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003725 "build_fcn": (
3726 build_scatter,
3727 TosaTensorGen.tgScatter,
3728 TosaTensorValuesGen.tvgDefault,
3729 None,
3730 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003731 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003732 "error_if_validators": (
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 TosaErrorValidator.evWrongRank,
3738 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003739 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003740 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003741 "resize": {
3742 "op": Op.RESIZE,
3743 "operands": (1, 0),
3744 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003745 "build_fcn": (
3746 build_resize,
3747 TosaTensorGen.tgNHWC,
3748 TosaTensorValuesGen.tvgDefault,
3749 TosaArgGen.agResize,
3750 ),
James Ward24dbc422022-10-19 12:20:31 +01003751 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003752 "invalid_test_validators": (
3753 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003754 ),
3755 "error_if_validators": (
3756 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003757 TosaErrorValidator.evScaleSmallerEqualZero,
3758 TosaErrorValidator.evScaleNLargerMax,
3759 TosaErrorValidator.evScaleDLargerMax,
3760 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003761 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003762 TosaErrorValidator.evBorderSmallerMin,
3763 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003764 TosaErrorValidator.evWrongInputType,
3765 TosaErrorValidator.evWrongOutputType,
3766 TosaErrorValidator.evWrongRank,
3767 TosaErrorValidator.evWrongInputList,
3768 TosaErrorValidator.evWrongOutputList,
3769 TosaErrorValidator.evBatchMismatch,
3770 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003771 TosaErrorValidator.evResizeOutputShapeMismatch,
3772 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003773 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003774 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003775 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003776 "cast": {
3777 "op": Op.CAST,
3778 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003779 "build_fcn": (
3780 build_cast,
3781 TosaTensorGen.tgBasic,
3782 TosaTensorValuesGen.tvgDefault,
3783 TosaArgGen.agCast,
3784 ),
James Ward8b390432022-08-12 20:48:56 +01003785 "types": (
3786 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003787 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003788 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003789 DType.INT8,
3790 DType.INT16,
3791 DType.INT32,
3792 DType.BOOL,
3793 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003794 "error_if_validators": (
3795 TosaErrorValidator.evWrongInputType,
3796 TosaErrorValidator.evWrongOutputType,
3797 TosaErrorValidator.evWrongInputList,
3798 TosaErrorValidator.evWrongOutputList,
3799 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003800 },
3801 "rescale": {
3802 "op": Op.RESCALE,
3803 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003804 "build_fcn": (
3805 build_rescale,
3806 TosaTensorGen.tgBasic,
3807 TosaTensorValuesGen.tvgDefault,
3808 TosaArgGen.agRescale,
3809 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003810 "types": [
3811 DType.UINT8,
3812 DType.INT8,
3813 DType.INT16,
3814 DType.INT32,
3815 DType.INT48,
3816 DType.UINT16,
3817 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003818 "error_if_validators": (
3819 TosaErrorValidator.evInputZeroPointNotZero,
3820 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003821 TosaErrorValidator.evU16InputZeroPointNotValid,
3822 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003823 TosaErrorValidator.evScaleTrue,
3824 TosaErrorValidator.evScaleNotTrue,
3825 TosaErrorValidator.evWrongInputType,
3826 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003827 TosaErrorValidator.evWrongInputList,
3828 TosaErrorValidator.evWrongOutputList,
3829 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003830 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003831 # Custom
3832 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003834 # Two varients of cond_if, one that generates one of two constant tensors (no
3835 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3836 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003837 "cond_if_const": {
3838 "op": Op.COND_IF,
3839 "operands": (0, 2),
3840 "build_fcn": (
3841 build_cond_if_const,
3842 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003843 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003844 TosaArgGen.agCondIf,
3845 ),
3846 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003847 "error_if_validators": (
3848 TosaErrorValidator.evOutputListThenGraphMismatch,
3849 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003850 TosaErrorValidator.evCondIfCondNotMatchingBool,
3851 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003852 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003853 },
3854 "cond_if_binary": {
3855 "op": Op.COND_IF,
3856 "operands": (2, 0),
3857 "build_fcn": (
3858 build_cond_if_binary,
3859 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003860 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003861 TosaArgGen.agCondIf,
3862 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003863 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003864 "error_if_validators": (
3865 TosaErrorValidator.evInputListThenGraphMismatch,
3866 TosaErrorValidator.evInputListElseGraphMismatch,
3867 TosaErrorValidator.evOutputListThenGraphMismatch,
3868 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003869 TosaErrorValidator.evCondIfCondNotMatchingBool,
3870 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003872 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003873 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003874 "while_loop": {
3875 "op": Op.WHILE_LOOP,
3876 "operands": (0, 1),
3877 "build_fcn": (
3878 build_while_loop,
3879 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003880 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003881 TosaArgGen.agWhileLoop,
3882 ),
3883 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003884 "error_if_validators": (
3885 TosaErrorValidator.evInputListOutputListMismatch,
3886 TosaErrorValidator.evInputListCondGraphMismatch,
3887 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3888 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3889 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003890 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003892 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003893 }
3894
Kevin Cheng550ccc52021-03-03 11:21:43 -08003895
Eric Kunzee5e26762020-10-13 16:11:07 -07003896class OutputShaper:
3897 # Methods in this class compute the expected output shape and datatype
3898 # for common classes of operations
3899 def __init__(self):
3900 pass
3901
3902 # These methods return arguments that can be used for
3903 # creating a new output tensor
3904 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003905 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
3906 if error_name != ErrorIf.RankMismatch:
3907 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003908 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003909
3910 shape = []
3911 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003912 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07003913 shape.append(b.shape[i])
3914 else:
3915 shape.append(a.shape[i])
3916
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003917 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003918 all_dtypes = [
3919 DType.INT8,
3920 DType.INT16,
3921 DType.INT32,
3922 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01003923 DType.FP16,
3924 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003925 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003926 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003927 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3928 outputDType = rng.choice(wrong_dtypes)
3929 else:
3930 outputDType = a.dtype
3931
3932 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003933
3934 @staticmethod
3935 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003936 assert len(a.shape) == len(b.shape)
3937 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003938
3939 shape = []
3940 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003941 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003942 shape.append(a.shape[i])
3943
Kevin Cheng550ccc52021-03-03 11:21:43 -08003944 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003945
3946 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003947 def unaryOp(ser, rng, a, error_name=None):
3948 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003949 all_dtypes = [
3950 DType.INT8,
3951 DType.INT16,
3952 DType.INT32,
3953 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003954 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01003955 DType.FP16,
3956 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003957 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003958 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3959 outputDType = rng.choice(wrong_dtypes)
3960 else:
3961 outputDType = a.dtype
3962
3963 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003964
3965 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003966 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003967 if error_name != ErrorIf.RankMismatch:
3968 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003969 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003970
3971 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003972 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003973 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003974 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3975 else:
3976 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07003977
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003978 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003979 all_dtypes = [
3980 DType.INT8,
3981 DType.INT16,
3982 DType.INT32,
3983 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003984 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01003985 DType.FP16,
3986 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003988 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3989 outputDType = rng.choice(wrong_dtypes)
3990 else:
3991 outputDType = a.dtype
3992
3993 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003994
3995 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003996 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003997 if error_name != ErrorIf.RankMismatch:
3998 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003999 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004000
4001 # Do broadcast
4002 shape = []
4003 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004004 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004005 shape.append(b.shape[i])
4006 else:
4007 shape.append(a.shape[i])
4008
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004009 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004010 wrong_dtypes = [
4011 DType.INT8,
4012 DType.INT16,
4013 DType.INT32,
4014 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004015 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004016 DType.FP16,
4017 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004018 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004019 outputDType = rng.choice(wrong_dtypes)
4020 else:
4021 outputDType = DType.BOOL
4022
4023 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004024
4025 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004026 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004027 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004028 if error_name not in [
4029 ErrorIf.AxisSmallerZero,
4030 ErrorIf.AxisLargerRank,
4031 ErrorIf.ShapeOfAxisNotOne,
4032 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01004033 shape[axis] = 1
4034 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4035 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004036
Matthew Haddond6ce7252021-09-29 15:35:44 +01004037 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004038 all_dtypes = [
4039 DType.INT8,
4040 DType.INT16,
4041 DType.INT32,
4042 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004043 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004044 DType.FP16,
4045 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004046 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01004047 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4048 outputDType = rng.choice(wrong_dtypes)
4049 else:
4050 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004051
Matthew Haddond6ce7252021-09-29 15:35:44 +01004052 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004053
4054 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004055 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004056 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004057
4058 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4059 del shape[axis]
4060
4061 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
4062 remove = rng.choice([True, False])
4063 if remove and len(shape) > 1:
4064 del shape[0]
4065 else:
4066 shape.append(1)
4067 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
4068 for i in range(len(shape)):
4069 shape[i] = shape[i] + rng.integers(1, 10)
4070
4071 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004072 all_dtypes = [
4073 DType.INT8,
4074 DType.INT16,
4075 DType.INT32,
4076 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004077 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004078 DType.FP16,
4079 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004080 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004081 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
4082 outputDType = rng.choice(wrong_dtypes)
4083 else:
4084 outputDType = DType.INT32
4085
4086 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004087
4088 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004089 def conv2dOp(
4090 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4091 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004092
4093 # IFM: NHWC
4094 # Filter: OHWI
4095 # OFM: NHWC
4096
Kevin Cheng550ccc52021-03-03 11:21:43 -08004097 h = (
4098 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004099 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004100 + padding[0]
4101 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004102 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004103 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004104
Kevin Cheng550ccc52021-03-03 11:21:43 -08004105 w = (
4106 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004107 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004108 + padding[2]
4109 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004110 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004111 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004112
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004113 if error_name == ErrorIf.ConvOutputShapeMismatch:
4114 choices = [1, 2, 3]
4115 change = rng.choice(choices)
4116 # increment in multiples of stride to not hit non-integer error case
4117 if change in [1, 3]:
4118 h = h + (rng.choice(choices) * strides[0])
4119 if change in [2, 3]:
4120 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004121
Eric Kunzee5e26762020-10-13 16:11:07 -07004122 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4123
James Ward8b390432022-08-12 20:48:56 +01004124 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004125 # Pick some potentially correct output dtype if input type is incorrect
4126 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004127 else:
James Ward8b390432022-08-12 20:48:56 +01004128 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004129
4130 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004131 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004132 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004133 else:
4134 excludes = [out_dtype]
4135 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004136 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004137
Kevin Cheng550ccc52021-03-03 11:21:43 -08004138 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004139
4140 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004141 def conv3dOp(
4142 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4143 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004144
4145 # IFM: NDHWC
4146 # Filter: ODHWI
4147 # OFM: NDHWC
4148
4149 d = (
4150 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004151 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004152 + padding[0]
4153 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004154 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07004155 ) // strides[0] + 1
4156
4157 h = (
4158 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004159 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004160 + padding[2]
4161 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004162 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07004163 ) // strides[1] + 1
4164
4165 w = (
4166 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004167 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004168 + padding[4]
4169 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004170 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07004171 ) // strides[2] + 1
4172
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004173 if error_name == ErrorIf.ConvOutputShapeMismatch:
4174 choices = [1, 2, 3, 4]
4175 change = rng.choice(choices)
4176 # increment in multiples of stride to not hit non-integer error case
4177 if change in [1, 4]:
4178 d = d + (rng.choice(choices) * strides[0])
4179 if change in [2, 4]:
4180 h = h + (rng.choice(choices) * strides[1])
4181 if change in [3, 4]:
4182 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00004183
Kevin Cheng1533b852021-09-01 12:51:58 -07004184 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4185
James Ward8b390432022-08-12 20:48:56 +01004186 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004187 # Pick some potentially correct output dtype if input type is incorrect
4188 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07004189 else:
James Ward8b390432022-08-12 20:48:56 +01004190 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004191
4192 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004193 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004194 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004195 else:
4196 excludes = [out_dtype]
4197 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004198 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07004199
4200 return ser.addOutput(ofm_shape, out_dtype)
4201
4202 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004203 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01004204 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004205 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004206 # IFM: NHWC
4207 # Filter: HWCM
4208 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004209
Kevin Cheng550ccc52021-03-03 11:21:43 -08004210 h = (
4211 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004212 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004213 + padding[0]
4214 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004215 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004216 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004217
Kevin Cheng550ccc52021-03-03 11:21:43 -08004218 w = (
4219 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004220 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004221 + padding[2]
4222 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004223 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004224 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004225
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004226 if error_name == ErrorIf.ConvOutputShapeMismatch:
4227 choices = [1, 2, 3]
4228 change = rng.choice(choices)
4229 # increment in multiples of stride to not hit non-integer error case
4230 if change in [1, 3]:
4231 h = h + (rng.choice(choices) * strides[0])
4232 if change in [2, 3]:
4233 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004234
Eric Kunzee5e26762020-10-13 16:11:07 -07004235 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4236
James Ward8b390432022-08-12 20:48:56 +01004237 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004238 # Pick some potentially correct output dtype if input type is incorrect
4239 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004240 else:
James Ward8b390432022-08-12 20:48:56 +01004241 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004242
4243 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004244 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004245 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004246 else:
4247 excludes = [out_dtype]
4248 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004249 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004250
Kevin Cheng550ccc52021-03-03 11:21:43 -08004251 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004252
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add and return the output tensor of a 2D pooling operator.

    The input ``ifm`` is NHWC. ``kernel[0]``, ``stride[0]``, ``pad[0]`` and
    ``pad[1]`` size the height; ``kernel[1]``, ``stride[1]``, ``pad[2]``
    and ``pad[3]`` size the width. Batch and channels come from the input.

    Args:
        ser: serializer whose ``addOutput`` creates the tensor.
        rng: random generator, only consulted for ERROR_IF variations.
        ifm: input tensor (its ``shape`` and ``dtype`` are read).
        kernel: kernel size, [y, x].
        stride: stride, [y, x].
        pad: padding; entries 0-1 for height, 2-3 for width.
        error_name: optional ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, dtype)`` returns.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride or negative padding: the test is invalid anyway,
        # so a 1x1 placeholder spatial size is good enough.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Grow H, W or both by whole multiples of the stride so the size
        # stays an integer result and only the shape mismatch is hit.
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004290
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add and return the output tensor of FULLY_CONNECTED.

    ``input`` is [N, IC] and ``filter`` is [OC, IC]; the output is [N, OC].
    The output dtype is ``accum_dtype`` unchanged. That dtype is validated
    in arg_gen (and also invalidated there for ErrorIf tests), so neither
    ``rng`` nor ``error_name`` is consulted here.

    Returns:
        Whatever ``ser.addOutput(shape, accum_dtype)`` returns.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004303
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add and return the output tensor of MATMUL.

    ``a`` is [N, H, C] and ``b`` is [N, C, W]; the output shape is
    ``[a.shape[0], a.shape[1], b.shape[2]]``.

    Output dtype:
      * WrongOutputType: a random dtype from a per-input-dtype list of
        incorrect types (only INT8, INT16, FP32, FP16 and BF16 inputs are
        handled).
      * WrongInputType: INT32, as a potentially correct placeholder.
      * otherwise: ``accum_dtype``, which is validated in arg_gen.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        float_types = (DType.FP32, DType.FP16, DType.BF16)
        if a.dtype == DType.INT8:
            # INT32 is left out of this list (presumably the valid result).
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
            ) + float_types
        elif a.dtype == DType.INT16:
            # INT48 is left out of this list (presumably the valid result).
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
            ) + float_types
        elif a.dtype in float_types:
            # Only integer types are offered for floating-point inputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004351
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add and return the output tensor of CONCAT.

    The output shape starts as a copy of the first input's shape. When the
    inputs can be concatenated, the ``axis`` sizes of all remaining inputs
    are added to it. For rank-mismatch and invalid-axis ERROR_IF cases the
    first input's shape is used as-is. The output dtype follows the first
    input unless a wrong output type is being generated.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    first = a[0]
    output_shape = first.shape.copy()

    # Tensors of different ranks, or an out-of-range axis, cannot be joined.
    cannot_concat = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_concat:
        for other in a[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        # Make the concat axis larger than the sum of the inputs.
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(all_dtypes - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004387
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add and return the output tensor of PAD.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    A PadOutputShapeMismatch test shrinks one random dimension by 1 or 2.
    The output dtype follows ``a`` unless a wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    output_shape = a.shape.copy()
    for idx, dim in enumerate(output_shape):
        output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(output_shape)))
        output_shape[victim] -= rng.choice([1, 2])

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004416
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add and return the output tensor of RESHAPE.

    The output shape is a copy of ``shape``. For a
    TensorSizeInputOutputMismatch test, every dimension is enlarged by a
    random 1-9, so the element count no longer matches. The output dtype
    follows ``a`` unless a wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    output_shape = shape.copy()
    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        output_shape = [dim + rng.integers(1, 10) for dim in output_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004441
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add and return the output tensor of SLICE.

    The output shape is a copy of ``size`` (``start`` is not used here).
    For a SizeOutputShapeMismatch test, every dimension is perturbed:
    sizes of 2 or less only grow by 1-2, while larger sizes move by ±1 or ±2.
    The output dtype follows ``a`` unless a wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004473
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add and return the output tensor of TILE.

    Each output dimension is ``a.shape[i] * multiples[i]``; ``multiples``
    must have one entry per input dimension. The output dtype follows ``a``
    unless a wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    in_shape = a.shape.copy()
    assert len(multiples) == len(in_shape)
    output_shape = [dim * multiples[i] for i, dim in enumerate(in_shape)]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004499
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add and return the output tensor of TRANSPOSE.

    Output dimension ``i`` is ``a.shape[perms[i]]``. For an
    IndexOutsideBounds test, ``perms`` is not consulted; every dimension is
    set to ``a.shape[0]`` instead. The output dtype follows ``a`` unless a
    wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    if error_name == ErrorIf.IndexOutsideBounds:
        # Presumably perms holds an out-of-range index here, so avoid it.
        output_shape = [a.shape[0] for _ in output_shape]
    else:
        output_shape = [a.shape[perms[i]] for i in range(len(output_shape))]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004529
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add and return the output tensor of GATHER.

    Outside WrongRank tests, ``values`` must be rank 3 and ``indices`` must be
    rank 2, and they must agree on dimension 0. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``. The output
    dtype follows ``values`` unless a wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        # Rank checks are skipped for WrongRank tests, which presumably
        # use deliberately wrong ranks.
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    out_n = values.shape[0]
    out_w = indices.shape[1]
    out_c = values.shape[2]
    output_shape = [out_n, out_w, out_c]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {values.dtype}))
    else:
        out_dtype = values.dtype

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004555
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add and return the output tensor of SCATTER.

    Outside WrongRank tests, ``values_in`` and ``input`` must be rank 3 and
    ``indices`` must be rank 2, with matching N, W and C dimensions as asserted
    below. The output uses ``values_in.shape`` directly (not a copy). The
    output dtype follows ``values_in`` unless a wrong output type is requested.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    output_shape = values_in.shape

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {values_in.dtype}))
    else:
        out_dtype = values_in.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004584
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add and return the output tensor of TABLE.

    The output keeps the input shape. Its dtype is INT32 for an INT16
    input and INT8 otherwise. Outside WrongInputType tests, the input
    must be INT8 or INT16. For WrongOutputType, any other listed dtype is
    chosen at random.

    Returns:
        Whatever ``ser.addOutput(input.shape, dtype)`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    expected = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        others = [
            dt
            for dt in [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            if dt != expected
        ]
        out_dtype = rng.choice(others)
    else:
        out_dtype = expected

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004604
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add and return the output tensor of RESIZE.

    The output height and width are computed from the input H and W
    (``input.shape[1]``/``[2]``) using ``scale`` = [y_n, y_d, x_n, x_d],
    ``offset`` = [y, x] and ``border`` = [y, x]:

        OH = ((IH - 1) * y_n - offset[0] + border[0]) // y_d + 1
        OW = ((IW - 1) * x_n - offset[1] + border[1]) // x_d + 1

    ``mode`` and ``input_dtype`` are not used here; the output dtype is
    ``output_dtype``. Various ERROR_IF cases perturb the result.

    Returns:
        Whatever ``serializer.addOutput(dims, output_dtype)`` returns.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the size formula below never divides by zero.
        y_num, y_den, x_num, x_den = [
            max(v, 1) for v in (y_num, y_den, x_num, x_den)
        ]

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # ERROR_IF tests may have altered scale, offset or border, which
        # can make the size invalid: keep it at least 1 and, unless the
        # test is about exceeding it, below MAX_RESIZE_DIMENSION.
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            out_h = min(out_h, MAX_RESIZE_DIMENSION - 1)
            out_w = min(out_w, MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # Move by a whole denominator step so we don't hit the
        # non-integer error case.
        which = rng.choice([1, 2, 3])
        if which in (1, 3):
            if out_h + y_den >= MAX_RESIZE_DIMENSION:
                out_h -= y_den
                assert out_h > 0  # Should have been caught in agResize
            else:
                out_h += y_den
        if which in (2, 3):
            if out_w + x_den >= MAX_RESIZE_DIMENSION:
                out_w -= x_den
                assert out_w > 0  # Should have been caught in agResize
            else:
                out_w += x_den

    if error_name == ErrorIf.WrongRank:
        # input.shape[3] is not read here (presumably the wrong-rank input
        # may not have it); shape[0] is repeated in the last slot instead.
        output_dims = [input.shape[0], out_h, out_w, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [
            input.shape[0] + rng.integers(1, 10),
            out_h,
            out_w,
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [
            input.shape[0],
            out_h,
            out_w,
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        output_dims = [input.shape[0], out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004683
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add and return the output tensor of a type-conversion operator.

    The output keeps the shape of ``val`` and uses ``out_dtype``; ``rng``
    and ``error_name`` are not consulted.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004687
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add and return the output tensor of TRANSPOSE_CONV2D.

    The output shape is supplied by the caller rather than computed here.
    For a ConvOutputShapeMismatch test, H (index 1) and/or W (index 2) are
    increased by a random 1-3. This is done on a copy, so the caller's
    ``output_shape`` is never modified.

    Output dtype:
      * WrongInputType: INT32, as a potentially correct placeholder.
      * WrongOutputType: a random usable dtype other than the expected one
        (both FP16 and FP32 are excluded for FP16 inputs).
      * otherwise: ``accum_dtype``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Work on a copy: modifying the argument in place would leak the
        # deliberately wrong shape into any other use the caller makes of it.
        output_shape = list(output_shape)
        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in [1, 3]:
            output_shape[1] = output_shape[1] + rng.choice(choices)
        if change in [2, 3]:
            output_shape[2] = output_shape[2] + rng.choice(choices)

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Presumably both FP16 and FP32 are acceptable results for FP16
            # inputs, so neither may be chosen as the "wrong" type.
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        wrong_dtypes = list(usableDTypes(excludes=excludes))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004712 return ser.addOutput(output_shape, out_dtype)