blob: 515e8bb2b546324b2ab0c608803631cd4a06ed00 [file] [log] [blame]
Eric Kunzea1d49852022-01-04 10:07:29 -08001# Copyright (c) 2020-2022, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnson05c711e2022-12-12 18:00:41 +000017from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010018from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010020from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010026 # Maximum rank of tensor supported by test generator.
27 TOSA_TENSOR_MAX_RANK = 6
28
Eric Kunzee5e26762020-10-13 16:11:07 -070029 def __init__(self, args):
30 self.args = args
31 self.basePath = args.output_dir
32 self.random_seed = args.random_seed
33 self.ser = None
34 self.rng = np.random.default_rng(self.random_seed)
35 self.createDynamicOpLists()
36 self.initOpListDefaults()
37 self.quantGen = TosaQuantGen()
38 # Force makeShape to do a specific starting shape
39 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010040 # Work out floating point range
41 self.random_fp_low = min(args.tensor_fp_value_range)
42 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070043
44 def createSerializer(self, opName, testPath):
45 self.testPath = os.path.join(opName, testPath)
46
47 fullPath = os.path.join(self.basePath, self.testPath)
48 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010049 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070050
51 def getSerializer(self):
52 return self.ser
53
54 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080055 with open(
56 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
57 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070058 fd.write(self.ser.serialize())
59
Kevin Cheng550ccc52021-03-03 11:21:43 -080060 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
61 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070062
Matthew Haddon74567092021-07-16 15:38:20 +010063 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000064 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010065 seed = self.random_seed + 1
66 self.rng = np.random.default_rng(seed)
67
Eric Kunzee5e26762020-10-13 16:11:07 -070068 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070069 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070070 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070071 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070072 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070073 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070074 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010075 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
76 elif dtype == DType.UINT8:
77 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070078 elif dtype == DType.INT16:
79 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010080 elif dtype == DType.UINT16:
81 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070082 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080083 return np.int32(
84 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
85 )
Eric Kunzee5e26762020-10-13 16:11:07 -070086 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080087 return np.int64(
88 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
89 )
James Ward8b390432022-08-12 20:48:56 +010090 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010091 return np.float16(
92 self.rng.uniform(
93 low=self.random_fp_low, high=self.random_fp_high, size=shape
94 )
95 )
James Ward24dbc422022-10-19 12:20:31 +010096 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010097 f32_tensor = np.float32(
98 self.rng.uniform(
99 low=self.random_fp_low, high=self.random_fp_high, size=shape
100 )
101 )
James Ward24dbc422022-10-19 12:20:31 +0100102 # Floor the last 16 bits of each f32 value
103 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100104 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100105 return np.float32(
106 self.rng.uniform(
107 low=self.random_fp_low, high=self.random_fp_high, size=shape
108 )
109 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700110 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800111 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700112
Kevin Cheng989cb052021-04-28 16:29:44 -0700113 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700114 placeholders = []
115
Kevin Cheng989cb052021-04-28 16:29:44 -0700116 assert len(shape_list) == len(dtype_list)
117
118 for idx, shape in enumerate(shape_list):
119 arr = self.getRandTensor(shape, dtype_list[idx])
120 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
122 return placeholders
123
Kevin Cheng989cb052021-04-28 16:29:44 -0700124 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700125 consts = []
126
Kevin Cheng989cb052021-04-28 16:29:44 -0700127 assert len(shape_list) == len(dtype_list)
128
129 for idx, shape in enumerate(shape_list):
130 arr = self.getRandTensor(shape, dtype_list[idx])
131 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700132
133 return consts
134
135 def makeShape(self, rank):
136 if self.targetted_shape:
137 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800138 return np.int32(
139 self.rng.integers(
140 low=self.args.tensor_shape_range[0],
141 high=self.args.tensor_shape_range[1],
142 size=rank,
143 )
144 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
146 def setTargetShape(self, shape):
147 self.targetted_shape = shape
148
149 def randInt(self, low=0, high=256):
150 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
151
152 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100153 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100154 return np.float32(
155 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
156 )
James Ward8b390432022-08-12 20:48:56 +0100157 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100158 return np.float16(
159 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
160 )
James Ward24dbc422022-10-19 12:20:31 +0100161 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100162 rand_f32 = np.float32(
163 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
164 )
James Ward24dbc422022-10-19 12:20:31 +0100165 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700166 elif dtype == DType.BOOL:
167 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700168 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700169 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700170 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700171 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100172 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 elif dtype == DType.INT16:
174 low, high = (-32768, 32768)
175 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800176 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700177 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800178 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 # Special size
180 return np.int64(self.rng.integers(low, high, size=1))[0]
181 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800182 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700183
184 return np.int32(self.rng.integers(low, high, size=1))[0]
185
186 def shapeStr(self, shape):
187
188 sStr = []
189 # Convert to strings
190 for i in shape:
191 sStr.append(str(i))
192
Kevin Cheng550ccc52021-03-03 11:21:43 -0800193 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100195 def typeStr(self, dtype):
196 if isinstance(dtype, list) or isinstance(dtype, tuple):
197 assert len(dtype) >= 2
198 strs = [self.typeStr(t) for t in dtype]
199 # Limit types to the first 2 as the 3rd is the accumulator
200 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700201 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100202 if dtype in DTYPE_ATTRIBUTES:
203 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700204 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100205 raise Exception(
206 "Unknown dtype, cannot convert to string: {}".format(dtype)
207 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100209 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100210 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100211 if dtype in DTYPE_ATTRIBUTES:
212 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100214 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
216 # Argument generators
217 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
218 # Where the string descriptor is used to generate the test name and
219 # The build_fcn_arg_list is expanded and passed to the operator test
220 # build function
221
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100222 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
223 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
224
Matthew Haddon848efb42021-09-09 12:30:53 +0100225 # build_placeholder returns an int, ABS/other ops does not
226 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000227 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100228 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000229 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000230 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100231 return result_tens
232
233 # Ensure new output type has correct qinfo
234 if error_name == ErrorIf.WrongOutputType:
235 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000236 qinfo = [
237 TosaQuantGen.getZeroPoint(self, a.dtype),
238 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
239 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100240
241 # Invalidate Input/Output list for error if checks.
242 input_list = [a.name]
243 output_list = [result_tens.name]
244 pCount, cCount = op["operands"]
245 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000246 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
247 self, error_name, input_list, output_list
248 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100249
Les Bell729b0352021-11-24 10:28:21 +0000250 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100251 self.ser,
252 validator_fcns,
253 error_name,
254 op=op,
255 input_dtype=a.dtype,
256 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000257 qinfo=qinfo,
258 result_tensor=result_tens,
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100259 input_list=input_list,
260 output_list=output_list,
261 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000262 ):
263 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100264
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000265 attr = None
266 if op["op"] == Op.NEGATE:
267 attr = ts.TosaSerializerAttribute()
268 attr.NegateAttribute(qinfo[0], qinfo[1])
269
270 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700271 return result_tens
272
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100273 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000274 result_tens = OutputShaper.binaryBroadcastOp(
275 self.ser, self.rng, a, b, error_name
276 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100277
278 # Invalidate Input/Output list for error if checks.
279 input_list = [a.name, b.name]
280 output_list = [result_tens.name]
281 pCount, cCount = op["operands"]
282 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000283 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
284 self, error_name, input_list, output_list
285 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100286
Les Bell729b0352021-11-24 10:28:21 +0000287 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100288 self.ser,
289 validator_fcns,
290 error_name,
291 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000292 input1=a,
293 input2=b,
294 input_dtype=a.dtype,
295 output_dtype=result_tens.dtype,
296 result_tensor=result_tens,
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100297 input_list=input_list,
298 output_list=output_list,
299 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000300 ):
301 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100302
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000303 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 return result_tens
305
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100306 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700307 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000308 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700309 return result_tens
310
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000311 def build_arithmetic_right_shift(
312 self, op, a, b, round, validator_fcns=None, error_name=None
313 ):
314 result_tens = OutputShaper.binaryBroadcastOp(
315 self.ser, self.rng, a, b, error_name
316 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100317
318 # Invalidate Input/Output list for error if checks.
319 input_list = [a.name, b.name]
320 output_list = [result_tens.name]
321 pCount, cCount = op["operands"]
322 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000323 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
324 self, error_name, input_list, output_list
325 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100326
Les Bell729b0352021-11-24 10:28:21 +0000327 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100328 self.ser,
329 validator_fcns,
330 error_name,
331 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000332 input1=a,
333 input2=b,
334 input_dtype=a.dtype,
335 output_dtype=result_tens.dtype,
336 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100337 input_list=input_list,
338 output_list=output_list,
339 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000340 ):
341 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800342
343 attr = ts.TosaSerializerAttribute()
344 attr.ArithmeticRightShiftAttribute(round)
345
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000346 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800347 return result_tens
348
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100349 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000350 result_tens = OutputShaper.binaryBroadcastOp(
351 self.ser, self.rng, a, b, error_name
352 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700353
354 # Special for multiply:
355 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100356 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700357 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100358 if error_name == ErrorIf.WrongOutputType:
359 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
360 outputDType = self.rng.choice(all_dtypes)
361 result_tens.setDtype(outputDType)
362
363 # Invalidate Input/Output list for error if checks.
364 input_list = [a.name, b.name]
365 output_list = [result_tens.name]
366 pCount, cCount = op["operands"]
367 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000368 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
369 self, error_name, input_list, output_list
370 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100371
Les Bell729b0352021-11-24 10:28:21 +0000372 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100373 self.ser,
374 validator_fcns,
375 error_name,
376 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000377 input1=a,
378 input2=b,
379 input_dtype=a.dtype,
380 output_dtype=result_tens.dtype,
381 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100382 input_list=input_list,
383 output_list=output_list,
384 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000385 ):
386 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700387
Kevin Chengaee1fac2020-11-11 13:54:06 -0800388 attr = ts.TosaSerializerAttribute()
389 attr.MulAttribute(shift)
390
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000391 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700392 return result_tens
393
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100394 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
395 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700396
Kevin Chengfe392ce2021-10-18 21:51:55 +0000397 attr = ts.TosaSerializerAttribute()
398 attr.TableAttribute(table)
399
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100400 # Invalidate Input/Output list for error if checks.
401 input_list = [a.name]
402 output_list = [result_tens.name]
403 pCount, cCount = op["operands"]
404 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000405 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
406 self, error_name, input_list, output_list
407 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100408
Les Bell729b0352021-11-24 10:28:21 +0000409 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410 self.ser,
411 validator_fcns,
412 error_name,
413 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000414 input_shape=a.shape,
415 input_dtype=a.dtype,
416 output_dtype=result_tens.dtype,
417 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100418 input_list=input_list,
419 output_list=output_list,
420 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000421 ):
422 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100423
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000424 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700425
426 return result_tens
427
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100428 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
429 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
430
431 # Invalidate Input/Output list for error if checks.
432 input_list = [cond.name, a.name, b.name]
433 output_list = [result_tens.name]
434 pCount, cCount = op["operands"]
435 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
437 self, error_name, input_list, output_list
438 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100439
Les Bell729b0352021-11-24 10:28:21 +0000440 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100441 self.ser,
442 validator_fcns,
443 error_name,
444 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000445 input1=cond,
446 input2=a,
447 input3=b,
448 input_shape=a.shape,
449 input_dtype=a.dtype,
450 output_dtype=result_tens.dtype,
451 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100452 input_list=input_list,
453 output_list=output_list,
454 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000455 ):
456 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100457
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000458 self.ser.addOperator(
459 op["op"],
460 input_list,
461 output_list,
462 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700463 return result_tens
464
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100465 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000466 result_tens = OutputShaper.binaryComparisonOp(
467 self.ser, self.rng, a, b, error_name
468 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100469
470 # Invalidate Input/Output list for error if checks.
471 input_list = [a.name, b.name]
472 output_list = [result_tens.name]
473 pCount, cCount = op["operands"]
474 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000475 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
476 self, error_name, input_list, output_list
477 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478
Les Bell729b0352021-11-24 10:28:21 +0000479 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100480 self.ser,
481 validator_fcns,
482 error_name,
483 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000484 input1=a,
485 input2=b,
486 input_shape=a.shape,
487 input_dtype=a.dtype,
488 output_shape=result_tens.shape,
489 output_dtype=result_tens.dtype,
490 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100491 input_list=input_list,
492 output_list=output_list,
493 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000494 ):
495 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100496
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000497 self.ser.addOperator(
498 op["op"],
499 input_list,
500 output_list,
501 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700502 return result_tens
503
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100504 def build_argmax(self, op, a, axis, validator_fcns, error_name):
505 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
506
507 # Invalidate Input/Output list for error if checks.
508 input_list = [a.name]
509 output_list = [result_tens.name]
510 pCount, cCount = op["operands"]
511 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000512 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
513 self, error_name, input_list, output_list
514 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100515
Les Bell729b0352021-11-24 10:28:21 +0000516 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100517 self.ser,
518 validator_fcns,
519 error_name,
520 op=op,
521 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000522 input_shape=a.shape,
523 input_dtype=a.dtype,
524 output_shape=result_tens.shape,
525 output_dtype=result_tens.dtype,
526 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100527 input_list=input_list,
528 output_list=output_list,
529 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000530 ):
531 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700532
533 attr = ts.TosaSerializerAttribute()
534 attr.AxisAttribute(axis)
535
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000536 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700537 return result_tens
538
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000539 def build_pool2d(
540 self,
541 op,
542 input,
James Ward8b390432022-08-12 20:48:56 +0100543 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000544 stride,
545 pad,
546 kernel,
547 validator_fcns=None,
548 error_name=None,
549 qinfo=None,
550 ):
551 result_tens = OutputShaper.pool2dOp(
552 self.ser, self.rng, input, kernel, stride, pad, error_name
553 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100554
555 # Ensure new output type has correct qinfo
556 if error_name == ErrorIf.WrongInputType:
557 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000558 qinfo = [
559 TosaQuantGen.getZeroPoint(self, input.dtype),
560 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
561 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100562
563 # Invalidate Input/Output list for error if checks.
564 input_list = [input.name]
565 output_list = [result_tens.name]
566 pCount, cCount = op["operands"]
567 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000568 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
569 self, error_name, input_list, output_list
570 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100571
Les Bell729b0352021-11-24 10:28:21 +0000572 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100573 self.ser,
574 validator_fcns,
575 error_name,
576 op=op,
577 input_shape=input.shape,
578 input_dtype=input.dtype,
579 output_shape=result_tens.shape,
580 output_dtype=result_tens.dtype,
581 kernel=kernel,
582 stride=stride,
583 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000584 qinfo=qinfo,
585 result_tensor=result_tens,
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100586 input_list=input_list,
587 output_list=output_list,
588 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000589 ):
590 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700591
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000592 if qinfo is None:
593 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700594
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000595 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100596 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000597
598 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700599 return result_tens
600
James Ward8b390432022-08-12 20:48:56 +0100601 def build_maxpool2d(
602 self,
603 op,
604 input,
605 stride,
606 pad,
607 kernel,
608 validator_fcns=None,
609 error_name=None,
610 qinfo=None,
611 ):
612 # Same as build_pool2d but manually sets accum_dtype value
613 # (maxpool has no accum_dtype)
614 return self.build_pool2d(
615 op,
616 input,
617 DType.UNKNOWN,
618 stride,
619 pad,
620 kernel,
621 validator_fcns,
622 error_name,
623 qinfo,
624 )
625
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000626 def build_conv2d(
627 self,
628 op,
629 ifm,
630 filter,
631 bias,
James Ward8b390432022-08-12 20:48:56 +0100632 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000633 strides,
634 padding,
635 dilations,
636 validator_fcns=None,
637 error_name=None,
638 qinfo=None,
639 ):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800640 assert len(padding) == 4
641 result_tens = OutputShaper.conv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100642 self.ser,
643 self.rng,
644 ifm,
645 filter,
646 accum_dtype,
647 strides,
648 padding,
649 dilations,
650 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000651 )
652
653 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000654 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
655 DType.INT8,
656 DType.UINT8,
657 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000658 qinfo = [
659 TosaQuantGen.getZeroPoint(self, ifm.dtype),
660 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
661 ]
Les Bell0e027d42021-11-09 14:42:14 +0000662
663 # Invalidate Input/Output list for error_if checks.
664 input_list = [ifm.name, filter.name, bias.name]
665 output_list = [result_tens.name]
666 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000667 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
668 self, error_name, input_list, output_list
669 )
Les Bell0e027d42021-11-09 14:42:14 +0000670
Les Bell729b0352021-11-24 10:28:21 +0000671 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000672 self.ser,
673 validator_fcns,
674 error_name,
675 op=op,
676 input_dtype=ifm.dtype,
677 weight_dtype=filter.dtype,
678 output_dtype=result_tens.dtype,
679 qinfo=qinfo,
680 input_list=input_list,
681 num_operands=num_operands,
682 output_list=output_list,
683 pad=padding,
684 stride=strides,
685 dilation=dilations,
686 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100687 weight_shape=filter.shape,
688 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000689 ):
690 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700691
692 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100693 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -0700694
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000695 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700696 return result_tens
697
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000698 def build_conv3d(
699 self,
700 op,
701 ifm,
702 filter,
703 bias,
James Ward8b390432022-08-12 20:48:56 +0100704 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000705 strides,
706 padding,
707 dilations,
708 validator_fcns=None,
709 error_name=None,
710 qinfo=None,
711 ):
Kevin Cheng1533b852021-09-01 12:51:58 -0700712 assert len(padding) == 6
713 result_tens = OutputShaper.conv3dOp(
James Ward8b390432022-08-12 20:48:56 +0100714 self.ser,
715 self.rng,
716 ifm,
717 filter,
718 accum_dtype,
719 strides,
720 padding,
721 dilations,
722 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000723 )
724
725 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000726 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
727 DType.INT8,
728 DType.UINT8,
729 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000730 qinfo = [
731 TosaQuantGen.getZeroPoint(self, ifm.dtype),
732 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
733 ]
Les Bell0e027d42021-11-09 14:42:14 +0000734
735 # Invalidate Input/Output list for error_if checks.
736 input_list = [ifm.name, filter.name, bias.name]
737 output_list = [result_tens.name]
738 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000739 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
740 self, error_name, input_list, output_list
741 )
Les Bell0e027d42021-11-09 14:42:14 +0000742
Les Bell729b0352021-11-24 10:28:21 +0000743 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000744 self.ser,
745 validator_fcns,
746 error_name,
747 op=op,
748 input_dtype=ifm.dtype,
749 weight_dtype=filter.dtype,
750 output_dtype=result_tens.dtype,
751 qinfo=qinfo,
752 input_list=input_list,
753 num_operands=num_operands,
754 output_list=output_list,
755 pad=padding,
756 stride=strides,
757 dilation=dilations,
758 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100759 weight_shape=filter.shape,
760 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000761 ):
762 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700763
764 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100765 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
Kevin Cheng1533b852021-09-01 12:51:58 -0700766
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000767 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700768 return result_tens
769
Kevin Cheng550ccc52021-03-03 11:21:43 -0800770 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000771 self,
772 op,
773 ifm,
774 filter,
775 bias,
James Ward8b390432022-08-12 20:48:56 +0100776 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000777 stride,
TatWai Chong24594f52022-06-08 00:48:04 -0700778 out_pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000779 output_shape,
780 validator_fcns=None,
781 error_name=None,
782 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800783 ):
TatWai Chong24594f52022-06-08 00:48:04 -0700784 assert len(out_pad) == 4
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000785 result_tens = OutputShaper.transposeConv2DOp(
James Ward8b390432022-08-12 20:48:56 +0100786 self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000787 )
Les Bell0e027d42021-11-09 14:42:14 +0000788
789 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000790 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
791 DType.INT8,
792 DType.UINT8,
793 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000794 qinfo = [
795 TosaQuantGen.getZeroPoint(self, ifm.dtype),
796 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
797 ]
Les Bell0e027d42021-11-09 14:42:14 +0000798
799 # Invalidate Input/Output list for error_if checks.
800 input_list = [ifm.name, filter.name, bias.name]
801 output_list = [result_tens.name]
802 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000803 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
804 self, error_name, input_list, output_list
805 )
Les Bell0e027d42021-11-09 14:42:14 +0000806
Les Bell729b0352021-11-24 10:28:21 +0000807 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000808 self.ser,
809 validator_fcns,
810 error_name,
811 op=op,
812 input_dtype=ifm.dtype,
813 weight_dtype=filter.dtype,
814 output_dtype=result_tens.dtype,
815 qinfo=qinfo,
816 input_list=input_list,
817 num_operands=num_operands,
818 output_list=output_list,
TatWai Chong24594f52022-06-08 00:48:04 -0700819 pad=out_pad,
Les Bell0e027d42021-11-09 14:42:14 +0000820 stride=stride,
Les Bell0e027d42021-11-09 14:42:14 +0000821 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100822 weight_shape=filter.shape,
823 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000824 ):
825 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700826
827 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100828 attr.TransposeConvAttribute(
829 out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
830 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700831
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000832 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700833 return result_tens
834
Kevin Cheng550ccc52021-03-03 11:21:43 -0800835 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000836 self,
837 op,
838 ifm,
839 filter,
840 bias,
James Ward8b390432022-08-12 20:48:56 +0100841 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000842 strides,
843 padding,
844 dilations,
845 validator_fcns=None,
846 error_name=None,
847 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800848 ):
849 result_tens = OutputShaper.depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100850 self.ser,
851 self.rng,
852 ifm,
853 filter,
854 accum_dtype,
855 strides,
856 padding,
857 dilations,
858 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000859 )
860
861 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000862 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
863 DType.INT8,
864 DType.UINT8,
865 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000866 qinfo = [
867 TosaQuantGen.getZeroPoint(self, ifm.dtype),
868 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
869 ]
Les Bell0e027d42021-11-09 14:42:14 +0000870
871 # Invalidate Input/Output list for error_if checks.
872 input_list = [ifm.name, filter.name, bias.name]
873 output_list = [result_tens.name]
874 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000875 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
876 self, error_name, input_list, output_list
877 )
Les Bell0e027d42021-11-09 14:42:14 +0000878
Les Bell729b0352021-11-24 10:28:21 +0000879 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000880 self.ser,
881 validator_fcns,
882 error_name,
883 op=op,
884 input_dtype=ifm.dtype,
885 weight_dtype=filter.dtype,
886 output_dtype=result_tens.dtype,
887 qinfo=qinfo,
888 input_list=input_list,
889 num_operands=num_operands,
890 output_list=output_list,
891 pad=padding,
892 stride=strides,
893 dilation=dilations,
894 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100895 weight_shape=filter.shape,
896 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000897 ):
898 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700899
900 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100901 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -0700902
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000903 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700904 return result_tens
905
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000906 def build_fully_connected(
James Ward8b390432022-08-12 20:48:56 +0100907 self,
908 op,
909 ifm,
910 filter,
911 bias,
912 accum_dtype,
913 validator_fcns=None,
914 error_name=None,
915 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000916 ):
917 result_tens = OutputShaper.fullyConnectedOp(
James Ward8b390432022-08-12 20:48:56 +0100918 self.ser, self.rng, ifm, filter, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000919 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100920
921 # Invalidate Input/Output list for error if checks.
922 input_list = [ifm.name, filter.name, bias.name]
923 output_list = [result_tens.name]
924 pCount, cCount = op["operands"]
925 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000926 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
927 self, error_name, input_list, output_list
928 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100929
Les Bell729b0352021-11-24 10:28:21 +0000930 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100931 self.ser,
932 validator_fcns,
933 error_name,
934 op=op,
935 input_shape=ifm.shape,
936 input_dtype=ifm.dtype,
937 weight_dtype=filter.dtype,
938 output_shape=result_tens.shape,
939 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000940 qinfo=qinfo,
941 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100942 input_list=input_list,
943 output_list=output_list,
944 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +0100945 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +0000946 ):
947 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700948
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000949 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100950 attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000951
952 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700953 return result_tens
954
James Ward8b390432022-08-12 20:48:56 +0100955 def build_matmul(
956 self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
957 ):
958 result_tens = OutputShaper.matmulOp(
959 self.ser, self.rng, a, b, accum_dtype, error_name
960 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100961
962 # Invalidate Input/Output list for error if checks.
963 input_list = [a.name, b.name]
964 output_list = [result_tens.name]
965 pCount, cCount = op["operands"]
966 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000967 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
968 self, error_name, input_list, output_list
969 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100970
Les Bell729b0352021-11-24 10:28:21 +0000971 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100972 self.ser,
973 validator_fcns,
974 error_name,
975 op=op,
976 input_shape=a.shape,
977 input_dtype=a.dtype,
978 input2_shape=b.shape,
979 input2_dtype=b.dtype,
980 output_shape=result_tens.shape,
981 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000982 qinfo=qinfo,
983 result_tensor=result_tens,
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100984 input_list=input_list,
985 output_list=output_list,
986 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +0100987 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +0000988 ):
989 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100990
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000991 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100992 attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000993
994 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700995 return result_tens
996
Matthew Haddond6ce7252021-09-29 15:35:44 +0100997 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
998 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
999
1000 # Invalidate Input/Output list for error if checks.
1001 input_list = [a.name]
1002 output_list = [result_tens.name]
1003 pCount, cCount = op["operands"]
1004 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001005 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1006 self, error_name, input_list, output_list
1007 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01001008
Les Bell729b0352021-11-24 10:28:21 +00001009 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01001010 self.ser,
1011 validator_fcns,
1012 error_name,
1013 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001014 axis=axis,
1015 input_shape=a.shape,
1016 output_shape=result_tens.shape,
1017 input_dtype=a.dtype,
1018 output_dtype=result_tens.dtype,
1019 result_tensor=result_tens,
Matthew Haddond6ce7252021-09-29 15:35:44 +01001020 input_list=input_list,
1021 output_list=output_list,
1022 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001023 ):
1024 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001025
1026 attr = ts.TosaSerializerAttribute()
1027 attr.AxisAttribute(axis)
1028
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001029 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001030 return result_tens
1031
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001032 def build_clamp(self, op, a, validator_fcns=None, error_name=None):
1033 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001034
Jeremy Johnson18e26662021-07-22 16:15:29 +01001035 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001036
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001037 if error_name == ErrorIf.MaxSmallerMin:
1038 # Make sure the numbers are different to invoke this error
1039 while v[0] == v[1]:
1040 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
1041 max_val = min(v)
1042 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001043 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001044 max_val = max(v)
1045 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001046
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001047 # Invalidate Input/Output list for error if checks.
1048 input_list = [a.name]
1049 output_list = [result_tens.name]
1050 pCount, cCount = op["operands"]
1051 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001052 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1053 self, error_name, input_list, output_list
1054 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001055
Les Bell729b0352021-11-24 10:28:21 +00001056 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001057 self.ser,
1058 validator_fcns,
1059 error_name,
1060 op=op,
1061 max_val=max_val,
1062 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001063 input_shape=a.shape,
1064 output_shape=result_tens.shape,
1065 input_dtype=a.dtype,
1066 output_dtype=result_tens.dtype,
1067 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001068 input_list=input_list,
1069 output_list=output_list,
1070 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001071 ):
1072 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001073
1074 attr = ts.TosaSerializerAttribute()
James Ward24dbc422022-10-19 12:20:31 +01001075 if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001076 attr.ClampAttribute(0, 0, min_val, max_val)
1077 else:
1078 attr.ClampAttribute(min_val, max_val, 0, 0)
1079
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001080 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001081 return result_tens
1082
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001083 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
1084 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001085 attr = ts.TosaSerializerAttribute()
1086
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001087 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
Eric Kunzee5e26762020-10-13 16:11:07 -07001088
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001089 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001090 return result_tens
1091
1092 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001093 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1094 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001095
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001096 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001097 return result_tens
1098
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001099 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
1100 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1101
1102 # Invalidate Input/Output list for error if checks.
1103 input_list = [a.name]
1104 output_list = [result_tens.name]
1105 pCount, cCount = op["operands"]
1106 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001107 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1108 self, error_name, input_list, output_list
1109 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001110
Les Bell729b0352021-11-24 10:28:21 +00001111 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001112 self.ser,
1113 validator_fcns,
1114 error_name,
1115 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001116 input_shape=a.shape,
1117 output_shape=result_tens.shape,
1118 input_dtype=a.dtype,
1119 output_dtype=result_tens.dtype,
1120 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001121 input_list=input_list,
1122 output_list=output_list,
1123 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001124 ):
1125 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001126
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001127 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001128 return result_tens
1129
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001130 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
1131 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1132
1133 # Invalidate Input/Output list for error if checks.
1134 input_list = [a.name]
1135 output_list = [result_tens.name]
1136 pCount, cCount = op["operands"]
1137 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001138 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1139 self, error_name, input_list, output_list
1140 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001141
Les Bell729b0352021-11-24 10:28:21 +00001142 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001143 self.ser,
1144 validator_fcns,
1145 error_name,
1146 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001147 input_shape=a.shape,
1148 output_shape=result_tens.shape,
1149 input_dtype=a.dtype,
1150 output_dtype=result_tens.dtype,
1151 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001152 input_list=input_list,
1153 output_list=output_list,
1154 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001155 ):
1156 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001157
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001158 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001159 return result_tens
1160
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001161 def build_concat(self, op, *a, validator_fcns=None, error_name=None):
1162 if error_name != ErrorIf.WrongInputType:
1163 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01001164
1165 # To store variable length list of input tensors we need to store axis along with it
1166 axis = a[-1]
1167 a = a[:-1]
1168
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001169 result_tens = OutputShaper.concatOp(
1170 self.ser, self.rng, axis, *a, error_name=error_name
1171 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001172
Matthew Haddon818ab902021-07-27 09:12:49 +01001173 input_tensor_names = []
1174 for tensor in a:
1175 input_tensor_names.append(tensor.name)
1176
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001177 # Invalidate Input/Output list for error if checks.
1178 input_list = input_tensor_names
1179 output_list = [result_tens.name]
1180 pCount, cCount = op["operands"]
1181 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001182 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1183 self, error_name, input_list, output_list
1184 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001185
Les Bell729b0352021-11-24 10:28:21 +00001186 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001187 self.ser,
1188 validator_fcns,
1189 error_name,
1190 op=op,
1191 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001192 input_shape=a[0].shape,
1193 output_shape=result_tens.shape,
1194 input_dtype=a[0].dtype,
1195 output_dtype=result_tens.dtype,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001196 inputs=a,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001197 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001198 input_list=input_list,
1199 output_list=output_list,
1200 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001201 ):
1202 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001203
1204 attr = ts.TosaSerializerAttribute()
1205 attr.AxisAttribute(axis)
1206
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001207 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddon848efb42021-09-09 12:30:53 +01001208 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001209
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001210 def build_pad(
1211 self,
1212 op,
1213 a,
1214 padding,
1215 pad_const_int,
1216 pad_const_float,
1217 validator_fcns=None,
1218 error_name=None,
1219 qinfo=None,
1220 ):
Matthew Haddone807aae2021-10-11 18:12:58 +01001221 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001222
Kevin Chengfe392ce2021-10-18 21:51:55 +00001223 attr = ts.TosaSerializerAttribute()
1224 attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
Eric Kunzee5e26762020-10-13 16:11:07 -07001225
Matthew Haddone807aae2021-10-11 18:12:58 +01001226 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001227 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001228 output_list = [result_tens.name]
1229 pCount, cCount = op["operands"]
1230 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001231 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1232 self, error_name, input_list, output_list
1233 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001234
Les Bell729b0352021-11-24 10:28:21 +00001235 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001236 self.ser,
1237 validator_fcns,
1238 error_name,
1239 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001240 input_shape=a.shape,
1241 output_shape=result_tens.shape,
1242 input_dtype=a.dtype,
1243 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001244 pad=padding,
1245 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001246 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001247 input_list=input_list,
1248 output_list=output_list,
1249 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001250 ):
1251 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001252
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001253 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddone86fd342021-09-07 16:12:21 +01001254 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001255
Matthew Haddone807aae2021-10-11 18:12:58 +01001256 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001257 result_tens = OutputShaper.reshapeOp(
1258 self.ser, self.rng, a, newShape, error_name
1259 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001260
1261 # Invalidate Input/Output list for error if checks.
1262 input_list = [a.name]
1263 output_list = [result_tens.name]
1264 pCount, cCount = op["operands"]
1265 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001266 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1267 self, error_name, input_list, output_list
1268 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001269
Les Bell729b0352021-11-24 10:28:21 +00001270 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001271 self.ser,
1272 validator_fcns,
1273 error_name,
1274 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001275 input_shape=a.shape,
1276 output_shape=result_tens.shape,
1277 input_dtype=a.dtype,
1278 output_dtype=result_tens.dtype,
1279 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001280 input_list=input_list,
1281 output_list=output_list,
1282 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001283 ):
1284 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001285
1286 attr = ts.TosaSerializerAttribute()
1287 attr.ReshapeAttribute(newShape)
1288
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001289 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001290 return result_tens
1291
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001292 def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
1293 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1294
1295 # Invalidate Input/Output list for error if checks.
1296 input_list = [a.name]
1297 output_list = [result_tens.name]
1298 pCount, cCount = op["operands"]
1299 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001300 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1301 self, error_name, input_list, output_list
1302 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001303
Les Bell729b0352021-11-24 10:28:21 +00001304 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001305 self.ser,
1306 validator_fcns,
1307 error_name,
1308 op=op,
1309 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001310 input_shape=a.shape,
1311 output_shape=result_tens.shape,
1312 input_dtype=a.dtype,
1313 output_dtype=result_tens.dtype,
1314 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001315 input_list=input_list,
1316 output_list=output_list,
1317 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001318 ):
1319 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001320
1321 attr = ts.TosaSerializerAttribute()
1322 attr.AxisAttribute(axis)
1323
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001324 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001325 return result_tens
1326
Matthew Haddone807aae2021-10-11 18:12:58 +01001327 def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
1328 result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001329
Kevin Chengfe392ce2021-10-18 21:51:55 +00001330 attr = ts.TosaSerializerAttribute()
1331 attr.TransposeAttribute(perms)
Eric Kunzee5e26762020-10-13 16:11:07 -07001332
Matthew Haddone807aae2021-10-11 18:12:58 +01001333 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001334 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001335 output_list = [result_tens.name]
1336 pCount, cCount = op["operands"]
1337 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001338 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1339 self, error_name, input_list, output_list
1340 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001341
Les Bell729b0352021-11-24 10:28:21 +00001342 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001343 self.ser,
1344 validator_fcns,
1345 error_name,
1346 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001347 input_shape=a.shape,
1348 output_shape=result_tens.shape,
Matthew Haddone807aae2021-10-11 18:12:58 +01001349 perms=perms,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001350 input_dtype=a.dtype,
1351 output_dtype=result_tens.dtype,
1352 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001353 input_list=input_list,
1354 output_list=output_list,
1355 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001356 ):
1357 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001358
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001359 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001360 return result_tens
1361
Matthew Haddone807aae2021-10-11 18:12:58 +01001362 def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001363 result_tens = OutputShaper.sliceOp(
1364 self.ser, self.rng, a, start, size, error_name
1365 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001366
1367 # Invalidate Input/Output list for error if checks.
1368 input_list = [a.name]
1369 output_list = [result_tens.name]
1370 pCount, cCount = op["operands"]
1371 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001372 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1373 self, error_name, input_list, output_list
1374 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001375
Les Bell729b0352021-11-24 10:28:21 +00001376 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001377 self.ser,
1378 validator_fcns,
1379 error_name,
1380 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001381 input_shape=a.shape,
1382 output_shape=result_tens.shape,
1383 input_dtype=a.dtype,
1384 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001385 start=start,
1386 size=size,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001387 result_tensor=result_tens,
Matthew Haddone807aae2021-10-11 18:12:58 +01001388 input_list=input_list,
1389 output_list=output_list,
1390 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001391 ):
1392 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001393
1394 attr = ts.TosaSerializerAttribute()
Matthew Haddone807aae2021-10-11 18:12:58 +01001395 attr.SliceAttribute(start, size)
Eric Kunzee5e26762020-10-13 16:11:07 -07001396
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001397 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001398 return result_tens
1399
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001400 def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
1401 result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
1402
1403 # Invalidate Input/Output list for error if checks.
1404 input_list = [a.name]
1405 output_list = [result_tens.name]
1406 pCount, cCount = op["operands"]
1407 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001408 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1409 self, error_name, input_list, output_list
1410 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001411
Les Bell729b0352021-11-24 10:28:21 +00001412 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001413 self.ser,
1414 validator_fcns,
1415 error_name,
1416 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001417 input_shape=a.shape,
1418 output_shape=result_tens.shape,
1419 input_dtype=a.dtype,
1420 output_dtype=result_tens.dtype,
1421 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001422 input_list=input_list,
1423 output_list=output_list,
1424 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001425 ):
1426 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001427
1428 attr = ts.TosaSerializerAttribute()
1429 attr.TileAttribute(multiples)
1430
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001431 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001432 return result_tens
1433
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001434 def build_gather(self, op, values, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001435
1436 # Create a new indicies tensor
1437 # here with data that doesn't exceed the dimensions of the values tensor
1438
Kevin Cheng550ccc52021-03-03 11:21:43 -08001439 K = values.shape[1] # K
1440 W = self.randInt(
1441 self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
1442 ) # W
1443 indicies_arr = np.int32(
1444 self.rng.integers(low=0, high=K, size=[values.shape[0], W])
1445 ) # (N, W)
1446 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001447
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001448 result_tens = OutputShaper.gatherOp(
1449 self.ser, self.rng, values, indicies, error_name
1450 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001451
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001452 # Invalidate Input/Output list for error if checks.
1453 input_list = [values.name, indicies.name]
1454 output_list = [result_tens.name]
1455 pCount, cCount = op["operands"]
1456 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001457 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1458 self, error_name, input_list, output_list
1459 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001460
Les Bell729b0352021-11-24 10:28:21 +00001461 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001462 self.ser,
1463 validator_fcns,
1464 error_name,
1465 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001466 input_shape=values.shape,
1467 output_shape=result_tens.shape,
1468 input_dtype=values.dtype,
1469 output_dtype=result_tens.dtype,
1470 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001471 input_list=input_list,
1472 output_list=output_list,
1473 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001474 ):
1475 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001476
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001477 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001478
1479 return result_tens
1480
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001481 def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
Kevin Cheng77d0f762020-11-24 10:26:32 -08001482
1483 # Create a new indicies tensor
1484 # here with data that doesn't exceed the dimensions of the values_in tensor
1485
Kevin Cheng550ccc52021-03-03 11:21:43 -08001486 K = values_in.shape[1] # K
1487 W = input.shape[1] # W
1488 indicies_arr = np.int32(
1489 self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
1490 ) # (N, W)
1491 indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001492
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001493 result_tens = OutputShaper.scatterOp(
1494 self.ser, self.rng, values_in, indicies, input, error_name
1495 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001496
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001497 # Invalidate Input/Output list for error if checks.
1498 input_list = [values_in.name, indicies.name, input.name]
1499 output_list = [result_tens.name]
1500 pCount, cCount = op["operands"]
1501 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001502 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1503 self, error_name, input_list, output_list
1504 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001505
Les Bell729b0352021-11-24 10:28:21 +00001506 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001507 self.ser,
1508 validator_fcns,
1509 error_name,
1510 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001511 input_shape=values_in.shape,
1512 output_shape=result_tens.shape,
1513 input_dtype=values_in.dtype,
1514 output_dtype=result_tens.dtype,
1515 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001516 input_list=input_list,
1517 output_list=output_list,
1518 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001519 ):
1520 return None
Kevin Cheng77d0f762020-11-24 10:26:32 -08001521
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001522 self.ser.addOperator(op["op"], input_list, output_list)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001523
Kevin Cheng77d0f762020-11-24 10:26:32 -08001524 return result_tens
1525
Kevin Cheng550ccc52021-03-03 11:21:43 -08001526 def build_resize(
1527 self,
1528 op,
1529 input,
1530 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001531 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001532 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001533 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001534 input_dtype,
1535 output_dtype,
Matthew Haddone86fd342021-09-07 16:12:21 +01001536 validator_fcns,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001537 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001538 ):
1539 result_tens = OutputShaper.resizeOp(
1540 self.ser,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001541 self.rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001542 input,
1543 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001544 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001545 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001546 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001547 input_dtype,
1548 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001549 error_name,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001550 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001551
Matthew Haddon848efb42021-09-09 12:30:53 +01001552 # Invalidate Input/Output list for error if checks.
1553 input_list = [input.name]
1554 output_list = [result_tens.name]
1555 pCount, cCount = op["operands"]
1556 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001557 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1558 self, error_name, input_list, output_list
1559 )
Matthew Haddone86fd342021-09-07 16:12:21 +01001560
Les Bell729b0352021-11-24 10:28:21 +00001561 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon848efb42021-09-09 12:30:53 +01001562 self.ser,
1563 validator_fcns,
1564 error_name,
1565 op=op,
1566 mode=mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001567 scale=scale,
Matthew Haddon848efb42021-09-09 12:30:53 +01001568 input_dtype=input_dtype,
1569 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01001570 input_shape=input.shape,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001571 output_shape=result_tens.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01001572 offset=offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001573 border=border,
Matthew Haddon848efb42021-09-09 12:30:53 +01001574 input_list=input_list,
1575 output_list=output_list,
Matthew Haddon693ba9e2021-09-22 11:24:37 +01001576 result_tensor=result_tens,
Matthew Haddon848efb42021-09-09 12:30:53 +01001577 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001578 ):
1579 return None
Matthew Haddone86fd342021-09-07 16:12:21 +01001580
Eric Kunzee5e26762020-10-13 16:11:07 -07001581 attr = ts.TosaSerializerAttribute()
Kevin Cheng77d0f762020-11-24 10:26:32 -08001582
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01001583 attr.ResizeAttribute(scale, offset, border, mode)
Eric Kunzee5e26762020-10-13 16:11:07 -07001584
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001585 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001586 return result_tens
1587
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001588 def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
1589 result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
1590 result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001591 self.ser.addOperator(
1592 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1593 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001594 return result_tens
1595
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001596 def build_const(self, op, val, validator_fcns=None, error_name=None):
Kevin Cheng17e92022021-10-01 14:33:33 -07001597 self.ser.addOutputTensor(val)
1598 return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001599
1600 # Type Conversion
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001601 def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001602 result_tens = OutputShaper.typeConversionOp(
1603 self.ser, self.rng, val, out_dtype, error_name
1604 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001605
1606 # Invalidate Input/Output list for error if checks.
1607 input_list = [val.name]
1608 output_list = [result_tens.name]
1609 pCount, cCount = op["operands"]
1610 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001611 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1612 self, error_name, input_list, output_list
1613 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001614
Les Bell729b0352021-11-24 10:28:21 +00001615 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001616 self.ser,
1617 validator_fcns,
1618 error_name,
1619 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001620 input_shape=val.shape,
1621 output_shape=result_tens.shape,
1622 input_dtype=val.dtype,
1623 output_dtype=result_tens.dtype,
1624 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001625 input_list=input_list,
1626 output_list=output_list,
1627 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001628 ):
1629 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001630
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001631 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001632 return result_tens
1633
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001634 def build_rescale(
1635 self,
1636 op,
1637 val,
1638 out_dtype,
1639 scale32,
1640 double_round,
1641 per_channel,
1642 validator_fcns,
1643 error_name,
1644 ):
1645 result_tens = OutputShaper.typeConversionOp(
1646 self.ser, self.rng, val, out_dtype, error_name
1647 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001648
1649 if per_channel:
1650 nc = val.shape[-1]
1651 else:
1652 nc = 1
1653
1654 in_type_width = self.typeWidth(val.dtype)
1655 out_type_width = self.typeWidth(out_dtype)
1656
Kevin Cheng3a478572021-01-22 17:21:02 -08001657 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001658 input_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001659 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07001660 elif val.dtype == DType.UINT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001661 input_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001662 in_type_width += 1
1663 elif error_name in [
1664 ErrorIf.InputZeroPointNotZero,
1665 ErrorIf.U16InputZeroPointNotValid,
1666 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001667 input_zp = self.randInt(-128, 128)
1668 if input_zp == 0:
1669 input_zp = input_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001670 in_type_width += 1
1671 elif val.dtype == DType.UINT16:
1672 # Must come after ErrorIf.U16InputZeroPointNotValid check
1673 input_zp = self.rng.choice([0, 32768])
1674 in_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001675 else:
1676 input_zp = 0
1677
Kevin Cheng3a478572021-01-22 17:21:02 -08001678 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001679 output_zp = self.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001680 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001681 elif out_dtype == DType.UINT8:
1682 output_zp = self.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001683 out_type_width += 1
1684 elif error_name in [
1685 ErrorIf.OutputZeroPointNotZero,
1686 ErrorIf.U16OutputZeroPointNotValid,
1687 ]:
Matthew Haddonc2025212021-10-08 21:21:05 +01001688 output_zp = self.randInt(-128, 128)
1689 if output_zp == 0:
1690 output_zp = output_zp + self.rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001691 out_type_width += 1
1692 elif out_dtype == DType.UINT16:
1693 # Must come after ErrorIf.U16OutputZeroPointNotValid check
1694 output_zp = self.rng.choice([0, 32768])
1695 out_type_width += 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001696 else:
1697 output_zp = 0
1698
1699 # Calculate scale based on:
1700 # scale = a *(2^output_width)/(2^input_width))
1701
1702 a = np.float32(self.rng.random(size=[nc]))
1703 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1704
1705 if scale32:
1706 pass
Matthew Haddonb724efc2021-08-25 16:40:29 +01001707 # Cap the scaling at 2^31 - 1 for scale32
Eric Kunzee5e26762020-10-13 16:11:07 -07001708 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1709 else:
1710 # Cap the scaling at 2^15 - 1 for scale16
1711 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1712
Kevin Cheng550ccc52021-03-03 11:21:43 -08001713 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001714
1715 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1716 shift_arr = np.int32(np.zeros(shape=[nc]))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00001717 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
1718 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001719
1720 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001721 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1722 scale_arr[i], scale32
1723 )
Eric Kunze750d27d2022-06-30 21:37:09 +00001724 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
1725 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07001726
Kevin Cheng550ccc52021-03-03 11:21:43 -08001727 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00001728 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001729 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00001730 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00001731 assert val.placeholderFilename
1732 values = np.load(
1733 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
1734 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00001735 val_adj = np.subtract(values, input_zp, dtype=np.int64)
1736 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
1737 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
1738 val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00001739 if not np.all(np.array_equal(values, val_adj)):
1740 # Values changed so overwrite file with new values
1741 np.save(
1742 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
1743 val_adj,
1744 False,
1745 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001746
Matthew Haddonc2025212021-10-08 21:21:05 +01001747 # Invalidate Input/Output list for error if checks.
1748 input_list = [val.name]
1749 output_list = [result_tens.name]
1750 pCount, cCount = op["operands"]
1751 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001752 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1753 self, error_name, input_list, output_list
1754 )
Matthew Haddonc2025212021-10-08 21:21:05 +01001755
1756 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00001757 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01001758 self.ser,
1759 validator_fcns,
1760 error_name,
1761 op=op,
1762 input_dtype=val.dtype,
1763 output_dtype=out_dtype,
1764 input_shape=val.shape,
1765 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001766 scale32=scale32,
1767 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01001768 input_list=input_list,
1769 output_list=output_list,
1770 result_tensor=result_tens,
1771 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001772 ):
1773 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01001774
Eric Kunzee5e26762020-10-13 16:11:07 -07001775 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001776 attr.RescaleAttribute(
1777 input_zp,
1778 output_zp,
1779 multiplier_arr,
1780 shift_arr,
1781 scale32,
1782 double_round,
1783 per_channel,
1784 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001785
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001786 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001787 return result_tens
1788
Jeremy Johnson05c711e2022-12-12 18:00:41 +00001789 def _get_condition_tensor(self, op, cond, error_name):
1790 if error_name == ErrorIf.CondIfCondNotMatchingBool:
1791 cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
1792 else:
1793 cond_type = DType.BOOL
1794 if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
1795 choice = self.rng.choice([1, 2])
1796 if choice == 1:
1797 cond_shape = [2]
1798 else:
1799 cond_shape = [1, 2]
1800 else:
1801 # Must be of size 1 (rank 0)
1802 cond_shape = []
1803 cond_tens = self.ser.addConst(cond_shape, cond_type, [cond])
1804 return cond_tens
1805
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001806 def build_cond_if_const(
1807 self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
1808 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001809 # For cond_if with constants, we're supplied with then/else tensors that we ignore
Jeremy Johnson05c711e2022-12-12 18:00:41 +00001810 # (except for the generated shape) and the condition. Build Then/Else blocks
Eric Kunzee5e26762020-10-13 16:11:07 -07001811 # and fill them with const nodes for the body.
1812
1813 # Condition tensor
Jeremy Johnson05c711e2022-12-12 18:00:41 +00001814 cond_tens = self._get_condition_tensor(op, cond, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001815
1816 # Make then/else tensors
1817 out_shape = then_tens.shape
Matthew Haddon630c17c2021-10-14 15:05:41 +01001818
1819 # Create an incorrect output shape for error_if tests
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001820 if error_name in [
1821 ErrorIf.CondIfOutputListThenGraphMismatch,
1822 ErrorIf.CondIfOutputListElseGraphMismatch,
1823 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01001824 incorrect_shape = deepcopy(then_tens.shape)
1825 for i in range(len(incorrect_shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001826 incorrect_shape[i] += (
1827 self.rng.choice([-3, -2, 2, 3])
1828 if incorrect_shape[i] > 3
1829 else self.rng.choice([1, 2, 4])
1830 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01001831 incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
1832
Jeremy Johnson18e26662021-07-22 16:15:29 +01001833 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
1834 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001835
1836 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08001837 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07001838
1839 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001840 then_block = "THEN_BLOCK"
1841 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001842 attr = ts.TosaSerializerAttribute()
1843 attr.CondIfAttribute(then_block, else_block)
1844
1845 # Finally, build the op and the two blocks
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001846 self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001847
1848 self.ser.startBasicBlock(then_block)
1849 # Build the actual then/else tensors inside their blocks
Matthew Haddon630c17c2021-10-14 15:05:41 +01001850 if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
1851 then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
1852 else:
1853 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001854 self.ser.addOutputTensor(then_tens)
1855
1856 self.ser.startBasicBlock(else_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001857 if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
1858 else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
1859 else:
1860 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001861 self.ser.addOutputTensor(else_tens)
1862
Les Bell729b0352021-11-24 10:28:21 +00001863 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01001864 self.ser,
1865 validator_fcns,
1866 error_name,
1867 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001868 basicBlocks=self.ser.basicBlocks,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00001869 cond=cond_tens,
Les Bell729b0352021-11-24 10:28:21 +00001870 ):
1871 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01001872
Eric Kunzee5e26762020-10-13 16:11:07 -07001873 return result_tens
1874
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001875 def build_cond_if_binary(
1876 self, op, a, b, cond, validator_fcns=None, error_name=None
1877 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001878 # For cond_if with a binary op in the then/else blocks, take a and b and
1879 # alternately add or subtract them based on the condition
1880
1881 # Condition tensor
Jeremy Johnson05c711e2022-12-12 18:00:41 +00001882 cond_tens = self._get_condition_tensor(op, cond, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001883
Kevin Cheng550ccc52021-03-03 11:21:43 -08001884 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001885
1886 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001887 then_block = "THEN_BLOCK"
1888 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001889 attr = ts.TosaSerializerAttribute()
1890 attr.CondIfAttribute(then_block, else_block)
1891
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001892 if error_name in [
1893 ErrorIf.CondIfInputListThenGraphMismatch,
1894 ErrorIf.CondIfInputListElseGraphMismatch,
1895 ErrorIf.CondIfOutputListElseGraphMismatch,
1896 ErrorIf.CondIfOutputListThenGraphMismatch,
1897 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01001898 incorrect_shape = a.shape.copy()
1899 for i in range(len(incorrect_shape)):
1900 incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
1901 incorrect_block_input = deepcopy(a)
1902 incorrect_block_input.shape = incorrect_shape
1903
Eric Kunzee5e26762020-10-13 16:11:07 -07001904 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001905 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001906 op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08001907 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001908
James Ward24dbc422022-10-19 12:20:31 +01001909 if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
Les Bell6040b4d2021-10-11 12:50:31 +01001910 then_op, else_op = Op.ADD, Op.SUB
1911 elif a.dtype in (DType.INT8, DType.INT16):
1912 then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
1913 else:
1914 assert False, f"No tests for DType: {a.dtype}"
Eric Kunzee5e26762020-10-13 16:11:07 -07001915
Les Bell6040b4d2021-10-11 12:50:31 +01001916 for block, op in ((then_block, then_op), (else_block, else_op)):
1917 self.ser.startBasicBlock(block)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001918 if (
1919 error_name == ErrorIf.CondIfInputListThenGraphMismatch
1920 and block == then_block
1921 ) or (
1922 error_name == ErrorIf.CondIfInputListElseGraphMismatch
1923 and block == else_block
1924 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01001925 self.ser.addInputTensor(incorrect_block_input)
1926 self.ser.addInputTensor(b)
1927 tens = self.ser.addOutput(a.shape, a.dtype)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001928 elif (
1929 error_name == ErrorIf.CondIfOutputListThenGraphMismatch
1930 and block == then_block
1931 ) or (
1932 error_name == ErrorIf.CondIfOutputListElseGraphMismatch
1933 and block == else_block
1934 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01001935 self.ser.addInputTensor(a)
1936 self.ser.addInputTensor(b)
1937 tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
1938 else:
1939 self.ser.addInputTensor(a)
1940 self.ser.addInputTensor(b)
1941 tens = self.ser.addOutput(a.shape, a.dtype)
Les Bell6040b4d2021-10-11 12:50:31 +01001942 self.ser.addOperator(op, [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001943
Les Bell729b0352021-11-24 10:28:21 +00001944 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01001945 self.ser,
1946 validator_fcns,
1947 error_name,
1948 op=op,
1949 a=a,
1950 b=b,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001951 basicBlocks=self.ser.basicBlocks,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00001952 cond=cond_tens,
Les Bell729b0352021-11-24 10:28:21 +00001953 ):
1954 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01001955
Eric Kunzee5e26762020-10-13 16:11:07 -07001956 return result_tens
1957
Matthew Haddon630c17c2021-10-14 15:05:41 +01001958 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001959 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001960
Kevin Cheng550ccc52021-03-03 11:21:43 -08001961 cond_block = "COND_BLOCK"
1962 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001963
1964 attr = ts.TosaSerializerAttribute()
1965 attr.WhileLoopAttribute(cond_block, body_block)
1966
1967 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001968 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001969 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001970 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001971
1972 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001973 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1974 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001975 if error_name == ErrorIf.InputListOutputListMismatch:
1976 incorrect_acc = deepcopy(acc)
1977 for i in range(len(incorrect_acc.shape)):
1978 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
1979 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
1980 else:
1981 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001982
1983 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001984 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001985 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08001986 [iter.name, a.name, acc.name],
1987 [iter_out.name, a_out.name, acc_out.name],
1988 attr,
1989 )
Kevin Chengb227ae52021-09-02 13:43:17 -07001990 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07001991
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001992 if error_name in [
1993 ErrorIf.InputListCondGraphMismatch,
1994 ErrorIf.InputListBodyGraphInputMismatch,
1995 ErrorIf.InputListBodyGraphOutputMismatch,
1996 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01001997 incorrect_iter = deepcopy(iter)
1998 for i in range(len(incorrect_iter.shape)):
1999 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2000 if len(incorrect_iter.shape) == 0:
2001 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2002
2003 incorrect_acc = deepcopy(acc)
2004 for i in range(len(incorrect_acc.shape)):
2005 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2006
Eric Kunzee5e26762020-10-13 16:11:07 -07002007 # COND block (input: iter, output: cond_tens )
2008 self.ser.startBasicBlock(cond_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002009 if error_name == ErrorIf.InputListCondGraphMismatch:
2010 self.ser.addInputTensor(incorrect_iter)
2011 self.ser.addInputTensor(a)
2012 self.ser.addInputTensor(incorrect_acc)
2013 else:
2014 self.ser.addInputTensor(iter)
2015 self.ser.addInputTensor(a)
2016 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002017 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002018
2019 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002020 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002021 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002022 cond_type = DType.BOOL
2023 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2024 choice = self.rng.choice([1, 2])
2025 if choice == 1:
2026 cond_shape = [3]
2027 else:
2028 cond_shape = [1, 2]
2029 else:
2030 cond_shape = []
2031 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002032
Kevin Cheng550ccc52021-03-03 11:21:43 -08002033 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002034
2035 # BODY block (input: a, acc, iter, output: a, acc, iter)
2036 # Note that local intermediate tensors need to be declared here for the outputs
2037 self.ser.startBasicBlock(body_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002038 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2039 self.ser.addInputTensor(incorrect_iter)
2040 self.ser.addInputTensor(a)
2041 self.ser.addInputTensor(incorrect_acc)
2042 else:
2043 self.ser.addInputTensor(iter)
2044 self.ser.addInputTensor(a)
2045 self.ser.addInputTensor(acc)
2046
Kevin Cheng550ccc52021-03-03 11:21:43 -08002047 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002048
2049 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002050 iter_body_out = self.ser.addIntermediate(
2051 incorrect_iter.shape, incorrect_iter.dtype
2052 )
2053 acc_body_out = self.ser.addIntermediate(
2054 incorrect_acc.shape, incorrect_acc.dtype
2055 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002056 else:
2057 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2058 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2059
Eric Kunzee5e26762020-10-13 16:11:07 -07002060 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2061 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2062 self.ser.addOutputTensor(iter_body_out)
2063 self.ser.addOutputTensor(a)
2064 self.ser.addOutputTensor(acc_body_out)
2065
Les Bell729b0352021-11-24 10:28:21 +00002066 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002067 self.ser,
2068 validator_fcns,
2069 error_name,
2070 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002071 basicBlocks=self.ser.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002072 ):
2073 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002074
Eric Kunzee5e26762020-10-13 16:11:07 -07002075 return acc_out
2076
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002077 def create_filter_lists(
2078 self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
2079 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002080 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
2081 default_test_rank_range = range(1, 5)
2082 if not shapeFilter:
2083 shapeFilter = [None]
2084
2085 # Calculate the filters based on what is requested and what the operator allows
2086 rmin, rmax = op["rank"]
2087 if rankFilter is not None:
2088 cleanRankFilter = []
2089 # Ensure rankFilter values are allowed by operator
2090 for rank in rankFilter:
2091 if rank >= rmin and rank <= rmax:
2092 cleanRankFilter.append(rank)
2093 elif rankFilter is None and shapeFilter[0] is None:
Jeremy Johnson03bec732021-10-07 12:06:00 +01002094 # Ensure default behaviour is bounded by default range or by operator,
2095 # whichever is the smaller range of ranks.
2096 opRankRange = range(rmin, rmax + 1)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002097 cleanRankFilter = (
2098 opRankRange
2099 if len(opRankRange) <= len(default_test_rank_range)
2100 else default_test_rank_range
2101 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002102 else:
2103 cleanRankFilter = range(rmin, rmax + 1)
2104
2105 dtypes = op["types"]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002106
Matthew Haddon1c00b712021-10-01 15:51:03 +01002107 if dtypeFilter is not None:
2108 cleanDtypeFilter = []
Jeremy Johnson03bec732021-10-07 12:06:00 +01002109 # Create list of operator dtypes filtered by requested dtypes
2110 for dtype in dtypes:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002111 if dtype in dtypeFilter or (
2112 isinstance(dtype, list) and dtype[0] in dtypeFilter
2113 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002114 cleanDtypeFilter.append(dtype)
2115 else:
2116 cleanDtypeFilter = dtypes
2117
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002118 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002119 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002120 "shapeFilter": shapeFilter,
2121 "rankFilter": cleanRankFilter,
2122 "dtypeFilter": cleanDtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002123 }
2124 return filterDict
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002125 elif testType == "negative":
Matthew Haddone807aae2021-10-11 18:12:58 +01002126 if validator is not None:
2127 validator_info = validator(check=False, op=op)
2128 else:
2129 return None
2130
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002131 error_arguments = validator_info["param_reqs"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002132
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002133 # Set parameters as required
2134 if error_arguments["rank"] is not None:
2135 rankFilter = error_arguments["rank"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002136 else:
2137 rankFilter = cleanRankFilter
2138
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002139 if error_arguments["dtype"] is not None:
2140 dtypeFilter = error_arguments["dtype"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002141 else:
2142 dtypeFilter = cleanDtypeFilter
2143
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002144 if error_arguments["shape"] is not None:
2145 shapeFilter = error_arguments["shape"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002146 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002147 shapeFilter = shapeFilter[
2148 :2
2149 ] # Reduce number of shapes to keep test numbers small
Matthew Haddon1c00b712021-10-01 15:51:03 +01002150
2151 filterDict = {
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002152 "shapeFilter": shapeFilter,
2153 "rankFilter": rankFilter,
2154 "dtypeFilter": dtypeFilter,
Matthew Haddon1c00b712021-10-01 15:51:03 +01002155 }
2156 return filterDict
2157
Kevin Cheng550ccc52021-03-03 11:21:43 -08002158 def genOpTestList(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002159 self,
2160 opName,
2161 shapeFilter=[None],
2162 rankFilter=None,
2163 dtypeFilter=None,
2164 testType="positive",
Kevin Cheng550ccc52021-03-03 11:21:43 -08002165 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002166
2167 try:
2168 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002169 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002170 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002171
2172 # Initialize a new random number generator
2173 self.rng = np.random.default_rng(self.random_seed)
2174
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002175 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002176
Eric Kunzee5e26762020-10-13 16:11:07 -07002177 # Test list consists of a tuple of:
2178 # (opName, testNameStr, dtype, shapeList, argumentsList)
2179 testList = []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002180 if testType == "negative" and "error_if_validators" in op:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002181 error_if_validators = op["error_if_validators"]
2182 else:
2183 error_if_validators = [None]
Eric Kunzee5e26762020-10-13 16:11:07 -07002184
Matthew Haddon1c00b712021-10-01 15:51:03 +01002185 for validator in error_if_validators:
2186 if validator is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002187 error_name = validator(check=False, op=op)["error_name"]
Matthew Haddon1c00b712021-10-01 15:51:03 +01002188 else:
2189 error_name = None
2190
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002191 filterDict = self.create_filter_lists(
2192 op, shapeFilter, rankFilter, dtypeFilter, testType, validator
2193 )
2194 if filterDict is None:
Matthew Haddone807aae2021-10-11 18:12:58 +01002195 return []
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002196 cleanRankFilter = filterDict["rankFilter"]
2197 cleanDtypeFilter = filterDict["dtypeFilter"]
2198 cleanShapeFilter = filterDict["shapeFilter"]
2199 # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002200
2201 for r in cleanRankFilter:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002202 for t in cleanDtypeFilter:
2203 for shape in cleanShapeFilter:
Matthew Haddon74567092021-07-16 15:38:20 +01002204 # Filter out by rank
2205 if shape is not None and len(shape) != r:
2206 continue
Matthew Haddon74567092021-07-16 15:38:20 +01002207 self.setTargetShape(shape)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002208 shapeList = tgen_fcn(self, op, r, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002209
Matthew Haddon74567092021-07-16 15:38:20 +01002210 shapeStr = self.shapeStr(shapeList[0])
2211 typeStr = self.typeStr(t)
Eric Kunzee5e26762020-10-13 16:11:07 -07002212
Matthew Haddon74567092021-07-16 15:38:20 +01002213 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
2214 argList = []
2215 if agen_fcn:
Matthew Haddon1c00b712021-10-01 15:51:03 +01002216 argList = agen_fcn(self, opName, shapeList, t, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002217 else:
Matthew Haddon74567092021-07-16 15:38:20 +01002218 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07002219
Matthew Haddon74567092021-07-16 15:38:20 +01002220 for argStr, args in argList:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002221 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002222 if argStr:
2223 testStr = "{}_{}_{}_{}".format(
2224 opName, shapeStr, typeStr, argStr
2225 )
2226 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002227 testStr = "{}_{}_{}".format(
2228 opName, shapeStr, typeStr
2229 )
2230 elif testType == "negative":
Matthew Haddone86fd342021-09-07 16:12:21 +01002231 if argStr:
2232 testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
2233 opName, error_name, shapeStr, typeStr, argStr
2234 )
2235 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002236 testStr = "{}_ERRORIF_{}_{}_{}".format(
2237 opName, error_name, shapeStr, typeStr
2238 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002239
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002240 testList.append(
2241 (opName, testStr, t, error_name, shapeList, args)
2242 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002243
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002244 if testType == "positive":
Matthew Haddon1c00b712021-10-01 15:51:03 +01002245 # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
2246 if "invalid_test_validators" in op:
2247 invalid_test_validators = op["invalid_test_validators"]
2248 clean_testList = []
2249 for test in testList:
2250 for validator_fcn in invalid_test_validators:
2251 remove_test = False
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002252 if validator_fcn(
2253 opName=test[0],
2254 input_dtype=test[2],
2255 shapeList=test[4],
2256 args=test[5],
2257 ):
Matthew Haddon1c00b712021-10-01 15:51:03 +01002258 remove_test = True
2259 if not remove_test:
2260 clean_testList.append(test)
2261 testList = clean_testList
Eric Kunzee5e26762020-10-13 16:11:07 -07002262
2263 return testList
2264
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002265 def serializeTest(
2266 self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
2267 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002268 try:
2269 op = self.TOSA_OP_LIST[opName]
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002270 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002271 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07002272
2273 # Create a serializer
2274 self.createSerializer(opName, testStr)
2275
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002276 build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
Matthew Haddone86fd342021-09-07 16:12:21 +01002277 if "error_if_validators" in op:
2278 error_if_validators = op["error_if_validators"]
2279 else:
2280 error_if_validators = None
2281
Kevin Cheng550ccc52021-03-03 11:21:43 -08002282 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07002283 num_operands = pCount + cCount
2284
2285 if isinstance(dtype_or_dtypeList, list):
2286 dtypeList = dtype_or_dtypeList
Kevin Cheng93a16282021-08-31 16:14:03 -07002287 elif op["op"] == Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002288 dtypeList = [dtype_or_dtypeList] * len(shapeList)
Kevin Cheng989cb052021-04-28 16:29:44 -07002289 else:
2290 dtypeList = [dtype_or_dtypeList] * (num_operands)
2291
Kevin Cheng93a16282021-08-31 16:14:03 -07002292 if op["op"] != Op.CONCAT:
Matthew Haddon818ab902021-07-27 09:12:49 +01002293 assert (
2294 len(shapeList) == num_operands
2295 ), "shapeList length {} must match number of operands {}".format(
2296 len(shapeList), num_operands
2297 )
2298 assert (
2299 len(dtypeList) == num_operands
2300 ), "dtypeList length {} must match number of operands {}".format(
2301 len(dtypeList), num_operands
2302 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002303
2304 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002305 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002306 except KeyError:
2307 qgen = None
2308
2309 # Build the random tensor operands and the test
2310 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08002311
Matthew Haddon1c00b712021-10-01 15:51:03 +01002312 if qgen is not None:
Matthew Haddone4ecdb22021-09-28 11:38:21 +01002313 qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
Matthew Haddon1c00b712021-10-01 15:51:03 +01002314 else:
2315 qinfo = None
2316
Eric Kunzeb5fabec2022-06-07 05:20:44 +00002317 tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
Jeremy Johnson81ee53d2022-03-23 15:32:34 +00002318
Matthew Haddon1c00b712021-10-01 15:51:03 +01002319 try:
2320 if error_if_validators is None:
2321 if qinfo is not None:
2322 resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
2323 else:
2324 resultName = build_fcn(self, op, *tens, *testArgs)
2325 else:
2326 if qinfo is not None:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002327 resultName = build_fcn(
2328 self,
2329 op,
2330 *tens,
2331 *testArgs,
2332 validator_fcns=error_if_validators,
2333 error_name=error_name,
2334 qinfo=qinfo,
2335 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002336 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002337 resultName = build_fcn(
2338 self,
2339 op,
2340 *tens,
2341 *testArgs,
2342 validator_fcns=error_if_validators,
2343 error_name=error_name,
2344 )
Matthew Haddon1c00b712021-10-01 15:51:03 +01002345 except TypeError as e:
Les Bell0e027d42021-11-09 14:42:14 +00002346 print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002347 raise e
2348
Les Bell729b0352021-11-24 10:28:21 +00002349 if resultName:
2350 # The test is valid, serialize it
2351 self.serialize("test")
2352 else:
2353 # The test is not valid
2354 print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002355
Eric Kunzee5e26762020-10-13 16:11:07 -07002356 def createDynamicOpLists(self):
2357
Jeremy Johnson00423432022-09-12 17:27:37 +01002358 if "conv2d_TEMPLATE" not in self.TOSA_OP_LIST:
2359 # Already created these lists (can occur when class is initialized more than once)
2360 return
2361
Eric Kunzee5e26762020-10-13 16:11:07 -07002362 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng1533b852021-09-01 12:51:58 -07002363 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002364
Kevin Cheng1533b852021-09-01 12:51:58 -07002365 for k in KERNELS_2D:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002366 testName = "conv2d_{}x{}".format(k[0], k[1])
2367 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
2368 self.TOSA_OP_LIST[testName]["filter"] = k
2369 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002370
Kevin Cheng550ccc52021-03-03 11:21:43 -08002371 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
2372 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2373 "depthwise_conv2d_TEMPLATE"
2374 ].copy()
2375 self.TOSA_OP_LIST[testName]["filter"] = k
2376 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002377
Kevin Cheng550ccc52021-03-03 11:21:43 -08002378 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
2379 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
2380 "transpose_conv2d_TEMPLATE"
2381 ].copy()
2382 self.TOSA_OP_LIST[testName]["filter"] = k
2383 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07002384
Kevin Cheng1533b852021-09-01 12:51:58 -07002385 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
2386 for k in KERNELS_3D:
2387 testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
2388 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
2389 self.TOSA_OP_LIST[testName]["filter"] = k
2390 self.TOSA_OP_LIST[testName]["template"] = False
2391
Eric Kunzee5e26762020-10-13 16:11:07 -07002392 # Delete any templates after having created any dynamic ops
2393 # This is a two-pass operation because it's bad practice to delete
2394 # keys from dictionaries while iterating
2395 keyList = []
2396 for k in self.TOSA_OP_LIST:
2397 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002398 if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07002399 keyList.append(k)
2400 continue
2401 except KeyError:
2402 pass
2403
2404 for k in keyList:
2405 del self.TOSA_OP_LIST[k]
2406
2407 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002408 """Fill in default fields for ops if they aren't already specified.
2409 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07002410 for op in self.TOSA_OP_LIST:
2411
2412 # Required fields
2413 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002414 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002415 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002416 raise Exception(
2417 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
2418 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002419
2420 try:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002421 fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002422 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002423 raise Exception(
2424 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
2425 op
2426 )
2427 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002428
2429 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002430 _ = self.TOSA_OP_LIST[op]["types"]
2431 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002432 raise Exception(
2433 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
2434 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002435
2436 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002437 _ = self.TOSA_OP_LIST[op]["op"]
2438 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002439 raise Exception(
2440 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
2441 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002442
2443 # Put in default rank range, if missing
2444 try:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002445 _ = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002446 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002447 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002448
2449 # Tensor operator list
2450 # 'op': op name
2451 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08002452 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
2453 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07002454 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
2455 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01002456 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002457
Kevin Cheng550ccc52021-03-03 11:21:43 -08002458 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01002459 TYPE_INT_FP = [
2460 DType.INT8,
2461 DType.INT16,
2462 DType.INT32,
2463 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002464 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002465 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002466 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07002467
Kevin Cheng550ccc52021-03-03 11:21:43 -08002468 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01002469 TYPE_FI32 = [
2470 DType.FP32,
2471 DType.FP16,
2472 DType.BF16,
2473 DType.INT32,
2474 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01002475 TYPE_FIB = [
2476 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01002477 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002478 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01002479 DType.INT8,
2480 DType.INT16,
2481 DType.INT32,
2482 DType.BOOL,
2483 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002484 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07002485
James Ward24dbc422022-10-19 12:20:31 +01002486 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07002487
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002488 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07002489 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07002490 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002491 [DType.INT8, DType.INT8, DType.INT32],
2492 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01002493 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002494 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01002495 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002496 [DType.FP32, DType.FP32, DType.FP32],
Kevin Cheng989cb052021-04-28 16:29:44 -07002497 ]
2498
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01002499 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002500
2501 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002502 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002503 "argmax": {
2504 "op": Op.ARGMAX,
2505 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002506 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002507 "build_fcn": (
2508 build_argmax,
2509 TosaTensorGen.tgBasic,
2510 TosaTensorValuesGen.tvgDefault,
2511 TosaArgGen.agAxis,
2512 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002513 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002514 "error_if_validators": (
2515 TosaErrorValidator.evAxisSmallerZero,
2516 TosaErrorValidator.evAxisLargerRank,
2517 TosaErrorValidator.evArgmaxOutputRankMismatch,
2518 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2519 TosaErrorValidator.evWrongRank,
2520 TosaErrorValidator.evWrongInputType,
2521 TosaErrorValidator.evWrongOutputType,
2522 TosaErrorValidator.evWrongInputList,
2523 TosaErrorValidator.evWrongOutputList,
2524 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002526 "avg_pool2d": {
2527 "op": Op.AVG_POOL2D,
2528 "operands": (1, 0),
2529 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002530 "build_fcn": (
2531 build_pool2d,
2532 TosaTensorGen.tgNHWC,
2533 TosaTensorValuesGen.tvgDefault,
2534 TosaArgGen.agPooling,
2535 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002536 "qgen": TosaQuantGen.qgUnary,
2537 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002538 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002539 "error_if_validators": (
2540 TosaErrorValidator.evKernelSmallerOne,
2541 TosaErrorValidator.evStrideSmallerOne,
2542 TosaErrorValidator.evPadSmallerZero,
2543 TosaErrorValidator.evWrongRank,
2544 TosaErrorValidator.evWrongInputType,
2545 TosaErrorValidator.evWrongOutputType,
2546 TosaErrorValidator.evWrongInputList,
2547 TosaErrorValidator.evWrongOutputList,
2548 TosaErrorValidator.evInputZeroPointNotZero,
2549 TosaErrorValidator.evOutputZeroPointNotZero,
2550 TosaErrorValidator.evPadLargerEqualKernel,
2551 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002552 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002553 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002554 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002555 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002556 "conv2d_TEMPLATE": {
2557 "op": Op.CONV2D,
2558 "operands": (1, 2),
2559 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002560 "build_fcn": (
2561 build_conv2d,
2562 TosaTensorGen.tgConv2D,
2563 TosaTensorValuesGen.tvgDefault,
2564 TosaArgGen.agConv,
2565 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002566 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002567 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002568 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2569 "error_if_validators": (
2570 TosaErrorValidator.evWrongInputType,
2571 TosaErrorValidator.evWrongOutputType,
2572 TosaErrorValidator.evWrongInputList,
2573 TosaErrorValidator.evWrongOutputList,
2574 TosaErrorValidator.evInputZeroPointNotZero,
2575 TosaErrorValidator.evWeightZeroPointNotZero,
2576 TosaErrorValidator.evPadSmallerZero,
2577 TosaErrorValidator.evStrideSmallerOne,
2578 TosaErrorValidator.evDilationSmallerOne,
2579 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002580 TosaErrorValidator.evConvOutputShapeMismatch,
2581 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002582 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002583 "template": True,
2584 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002585 # Templated operator. Filled in by createDynamicOpLists
2586 "conv3d_TEMPLATE": {
2587 "op": Op.CONV3D,
2588 "operands": (1, 2),
2589 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002590 "build_fcn": (
2591 build_conv3d,
2592 TosaTensorGen.tgConv3D,
2593 TosaTensorValuesGen.tvgDefault,
2594 TosaArgGen.agConv,
2595 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002596 "qgen": TosaQuantGen.qgConv,
2597 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002598 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2599 "error_if_validators": (
2600 TosaErrorValidator.evWrongInputType,
2601 TosaErrorValidator.evWrongOutputType,
2602 TosaErrorValidator.evWrongInputList,
2603 TosaErrorValidator.evWrongOutputList,
2604 TosaErrorValidator.evInputZeroPointNotZero,
2605 TosaErrorValidator.evWeightZeroPointNotZero,
2606 TosaErrorValidator.evPadSmallerZero,
2607 TosaErrorValidator.evStrideSmallerOne,
2608 TosaErrorValidator.evDilationSmallerOne,
2609 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002610 TosaErrorValidator.evConvOutputShapeMismatch,
2611 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002612 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002613 "template": True,
2614 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002615 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002616 "depthwise_conv2d_TEMPLATE": {
2617 "op": Op.DEPTHWISE_CONV2D,
2618 "operands": (1, 2),
2619 "filter": [1, 1],
2620 "rank": (4, 4),
2621 "build_fcn": (
2622 build_depthwise_conv2d,
2623 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002624 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002625 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002626 ),
2627 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002628 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002629 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2630 "error_if_validators": (
2631 TosaErrorValidator.evWrongInputType,
2632 TosaErrorValidator.evWrongOutputType,
2633 TosaErrorValidator.evWrongInputList,
2634 TosaErrorValidator.evWrongOutputList,
2635 TosaErrorValidator.evInputZeroPointNotZero,
2636 TosaErrorValidator.evWeightZeroPointNotZero,
2637 TosaErrorValidator.evPadSmallerZero,
2638 TosaErrorValidator.evStrideSmallerOne,
2639 TosaErrorValidator.evDilationSmallerOne,
2640 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002641 TosaErrorValidator.evConvOutputShapeMismatch,
2642 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002643 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002644 "template": True,
2645 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002646 "fully_connected": {
2647 "op": Op.FULLY_CONNECTED,
2648 "operands": (1, 2),
2649 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002650 "build_fcn": (
2651 build_fully_connected,
2652 TosaTensorGen.tgFullyConnected,
2653 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002654 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002655 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002656 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002657 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002658 "error_if_validators": (
2659 TosaErrorValidator.evInputZeroPointNotZero,
2660 TosaErrorValidator.evWeightZeroPointNotZero,
2661 TosaErrorValidator.evWrongRank,
2662 TosaErrorValidator.evWrongInputType,
2663 TosaErrorValidator.evWrongOutputType,
2664 TosaErrorValidator.evWrongInputList,
2665 TosaErrorValidator.evWrongOutputList,
2666 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002668 "matmul": {
2669 "op": Op.MATMUL,
2670 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002671 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002672 "build_fcn": (
2673 build_matmul,
2674 TosaTensorGen.tgMatmul,
2675 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002676 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002678 "qgen": TosaQuantGen.qgMatmul,
2679 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002680 "error_if_validators": (
2681 TosaErrorValidator.evInputZeroPointNotZero,
2682 TosaErrorValidator.evWrongRank,
2683 TosaErrorValidator.evWrongInputType,
2684 TosaErrorValidator.evWrongOutputType,
2685 TosaErrorValidator.evWrongInputList,
2686 TosaErrorValidator.evWrongOutputList,
2687 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002688 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002689 "max_pool2d": {
2690 "op": Op.MAX_POOL2D,
2691 "operands": (1, 0),
2692 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002693 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002694 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002695 TosaTensorGen.tgNHWC,
2696 TosaTensorValuesGen.tvgDefault,
2697 TosaArgGen.agPooling,
2698 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002699 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002700 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002701 "error_if_validators": (
2702 TosaErrorValidator.evKernelSmallerOne,
2703 TosaErrorValidator.evStrideSmallerOne,
2704 TosaErrorValidator.evPadSmallerZero,
2705 TosaErrorValidator.evWrongRank,
2706 TosaErrorValidator.evWrongInputType,
2707 TosaErrorValidator.evWrongOutputType,
2708 TosaErrorValidator.evWrongInputList,
2709 TosaErrorValidator.evWrongOutputList,
2710 TosaErrorValidator.evPadLargerEqualKernel,
2711 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002712 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002713 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002714 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002715 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002716 "transpose_conv2d_TEMPLATE": {
2717 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002718 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002719 "rank": (4, 4),
2720 "build_fcn": (
2721 build_transpose_conv2d,
2722 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002723 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002724 TosaArgGen.agTransposeConv2D,
2725 ),
2726 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002727 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002728 "invalid_test_validators": (
2729 TosaInvalidValidator.ivHeightWidthInvalid,
2730 TosaInvalidValidator.ivNonPositiveOutputShape,
2731 ),
2732 "error_if_validators": (
2733 TosaErrorValidator.evWrongInputType,
2734 TosaErrorValidator.evWrongOutputType,
2735 TosaErrorValidator.evWrongInputList,
2736 TosaErrorValidator.evWrongOutputList,
2737 TosaErrorValidator.evInputZeroPointNotZero,
2738 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002739 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002740 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002741 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002742 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002743 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002744 "template": True,
2745 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002746 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002747 "clamp": {
2748 "op": Op.CLAMP,
2749 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002750 "build_fcn": (
2751 build_clamp,
2752 TosaTensorGen.tgBasic,
2753 TosaTensorValuesGen.tvgDefault,
2754 None,
2755 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002756 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002757 "error_if_validators": (
2758 TosaErrorValidator.evMaxSmallerMin,
2759 TosaErrorValidator.evWrongInputType,
2760 TosaErrorValidator.evWrongOutputType,
2761 TosaErrorValidator.evWrongInputList,
2762 TosaErrorValidator.evWrongOutputList,
2763 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002764 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002765 "sigmoid": {
2766 "op": Op.SIGMOID,
2767 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002768 "build_fcn": (
2769 build_sigmoid,
2770 TosaTensorGen.tgBasic,
2771 TosaTensorValuesGen.tvgDefault,
2772 None,
2773 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002774 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002775 "error_if_validators": (
2776 TosaErrorValidator.evWrongInputType,
2777 TosaErrorValidator.evWrongOutputType,
2778 TosaErrorValidator.evWrongInputList,
2779 TosaErrorValidator.evWrongOutputList,
2780 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002781 },
2782 "tanh": {
2783 "op": Op.TANH,
2784 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002785 "build_fcn": (
2786 build_tanh,
2787 TosaTensorGen.tgBasic,
2788 TosaTensorValuesGen.tvgDefault,
2789 None,
2790 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002791 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002792 "error_if_validators": (
2793 TosaErrorValidator.evWrongInputType,
2794 TosaErrorValidator.evWrongOutputType,
2795 TosaErrorValidator.evWrongInputList,
2796 TosaErrorValidator.evWrongOutputList,
2797 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002798 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002799 # Elementwise Binary Operators
2800 "add": {
2801 "op": Op.ADD,
2802 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002803 "build_fcn": (
2804 build_binary_broadcast,
2805 TosaTensorGen.tgBroadcastFuzz,
2806 TosaTensorValuesGen.tvgAddSub,
2807 None,
2808 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002809 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002810 "error_if_validators": (
2811 TosaErrorValidator.evRankMismatch,
2812 TosaErrorValidator.evWrongInputType,
2813 TosaErrorValidator.evWrongOutputType,
2814 TosaErrorValidator.evWrongInputList,
2815 TosaErrorValidator.evWrongOutputList,
2816 TosaErrorValidator.evDimensionMismatch,
2817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002819 "arithmetic_right_shift": {
2820 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2821 "operands": (2, 0),
2822 "build_fcn": (
2823 build_arithmetic_right_shift,
2824 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002825 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002826 TosaArgGen.agArithmeticRightShift,
2827 ),
2828 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002829 "error_if_validators": (
2830 TosaErrorValidator.evRankMismatch,
2831 TosaErrorValidator.evWrongInputType,
2832 TosaErrorValidator.evWrongOutputType,
2833 TosaErrorValidator.evWrongInputList,
2834 TosaErrorValidator.evWrongOutputList,
2835 TosaErrorValidator.evDimensionMismatch,
2836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002837 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002838 "bitwise_and": {
2839 "op": Op.BITWISE_AND,
2840 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002841 "build_fcn": (
2842 build_binary_broadcast,
2843 TosaTensorGen.tgBroadcastFuzz,
2844 TosaTensorValuesGen.tvgDefault,
2845 None,
2846 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002847 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002848 "error_if_validators": (
2849 TosaErrorValidator.evRankMismatch,
2850 TosaErrorValidator.evWrongInputType,
2851 TosaErrorValidator.evWrongOutputType,
2852 TosaErrorValidator.evWrongInputList,
2853 TosaErrorValidator.evWrongOutputList,
2854 TosaErrorValidator.evDimensionMismatch,
2855 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002856 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002857 "bitwise_or": {
2858 "op": Op.BITWISE_OR,
2859 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002860 "build_fcn": (
2861 build_binary_broadcast,
2862 TosaTensorGen.tgBroadcastFuzz,
2863 TosaTensorValuesGen.tvgDefault,
2864 None,
2865 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002866 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002867 "error_if_validators": (
2868 TosaErrorValidator.evRankMismatch,
2869 TosaErrorValidator.evWrongInputType,
2870 TosaErrorValidator.evWrongOutputType,
2871 TosaErrorValidator.evWrongInputList,
2872 TosaErrorValidator.evWrongOutputList,
2873 TosaErrorValidator.evDimensionMismatch,
2874 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002875 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002876 "bitwise_xor": {
2877 "op": Op.BITWISE_XOR,
2878 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002879 "build_fcn": (
2880 build_binary_broadcast,
2881 TosaTensorGen.tgBroadcastFuzz,
2882 TosaTensorValuesGen.tvgDefault,
2883 None,
2884 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002885 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002886 "error_if_validators": (
2887 TosaErrorValidator.evRankMismatch,
2888 TosaErrorValidator.evWrongInputType,
2889 TosaErrorValidator.evWrongOutputType,
2890 TosaErrorValidator.evWrongInputList,
2891 TosaErrorValidator.evWrongOutputList,
2892 TosaErrorValidator.evDimensionMismatch,
2893 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002894 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002895 "intdiv": {
2896 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002897 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002898 "build_fcn": (
2899 build_binary_broadcast,
2900 TosaTensorGen.tgBroadcastFuzz,
2901 TosaTensorValuesGen.tvgIntDiv,
2902 None,
2903 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002904 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002905 "error_if_validators": (
2906 TosaErrorValidator.evRankMismatch,
2907 TosaErrorValidator.evWrongInputType,
2908 TosaErrorValidator.evWrongOutputType,
2909 TosaErrorValidator.evWrongInputList,
2910 TosaErrorValidator.evWrongOutputList,
2911 TosaErrorValidator.evDimensionMismatch,
2912 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002913 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002914 "logical_and": {
2915 "op": Op.LOGICAL_AND,
2916 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002917 "build_fcn": (
2918 build_binary_broadcast,
2919 TosaTensorGen.tgBroadcastFuzz,
2920 TosaTensorValuesGen.tvgDefault,
2921 None,
2922 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002923 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002924 "error_if_validators": (
2925 TosaErrorValidator.evRankMismatch,
2926 TosaErrorValidator.evWrongInputType,
2927 TosaErrorValidator.evWrongOutputType,
2928 TosaErrorValidator.evWrongInputList,
2929 TosaErrorValidator.evWrongOutputList,
2930 TosaErrorValidator.evDimensionMismatch,
2931 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002932 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002933 "logical_left_shift": {
2934 "op": Op.LOGICAL_LEFT_SHIFT,
2935 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002936 "build_fcn": (
2937 build_binary_broadcast,
2938 TosaTensorGen.tgBroadcastFuzz,
2939 TosaTensorValuesGen.tvgLogicalShift,
2940 None,
2941 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002942 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002943 "error_if_validators": (
2944 TosaErrorValidator.evRankMismatch,
2945 TosaErrorValidator.evWrongInputType,
2946 TosaErrorValidator.evWrongOutputType,
2947 TosaErrorValidator.evWrongInputList,
2948 TosaErrorValidator.evWrongOutputList,
2949 TosaErrorValidator.evDimensionMismatch,
2950 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002951 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002952 "logical_right_shift": {
2953 "op": Op.LOGICAL_RIGHT_SHIFT,
2954 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002955 "build_fcn": (
2956 build_binary_broadcast,
2957 TosaTensorGen.tgBroadcastFuzz,
2958 TosaTensorValuesGen.tvgLogicalShift,
2959 None,
2960 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002961 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002962 "error_if_validators": (
2963 TosaErrorValidator.evRankMismatch,
2964 TosaErrorValidator.evWrongInputType,
2965 TosaErrorValidator.evWrongOutputType,
2966 TosaErrorValidator.evWrongInputList,
2967 TosaErrorValidator.evWrongOutputList,
2968 TosaErrorValidator.evDimensionMismatch,
2969 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002970 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002971 "logical_or": {
2972 "op": Op.LOGICAL_OR,
2973 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002974 "build_fcn": (
2975 build_binary_broadcast,
2976 TosaTensorGen.tgBroadcastFuzz,
2977 TosaTensorValuesGen.tvgDefault,
2978 None,
2979 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002980 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002981 "error_if_validators": (
2982 TosaErrorValidator.evRankMismatch,
2983 TosaErrorValidator.evWrongInputType,
2984 TosaErrorValidator.evWrongOutputType,
2985 TosaErrorValidator.evWrongInputList,
2986 TosaErrorValidator.evWrongOutputList,
2987 TosaErrorValidator.evDimensionMismatch,
2988 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002990 "logical_xor": {
2991 "op": Op.LOGICAL_XOR,
2992 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002993 "build_fcn": (
2994 build_binary_broadcast,
2995 TosaTensorGen.tgBroadcastFuzz,
2996 TosaTensorValuesGen.tvgDefault,
2997 None,
2998 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002999 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003000 "error_if_validators": (
3001 TosaErrorValidator.evRankMismatch,
3002 TosaErrorValidator.evWrongInputType,
3003 TosaErrorValidator.evWrongOutputType,
3004 TosaErrorValidator.evWrongInputList,
3005 TosaErrorValidator.evWrongOutputList,
3006 TosaErrorValidator.evDimensionMismatch,
3007 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003008 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003009 "maximum": {
3010 "op": Op.MAXIMUM,
3011 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003012 "build_fcn": (
3013 build_binary_broadcast,
3014 TosaTensorGen.tgBroadcastFuzz,
3015 TosaTensorValuesGen.tvgDefault,
3016 None,
3017 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003018 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003019 "error_if_validators": (
3020 TosaErrorValidator.evRankMismatch,
3021 TosaErrorValidator.evWrongInputType,
3022 TosaErrorValidator.evWrongOutputType,
3023 TosaErrorValidator.evWrongInputList,
3024 TosaErrorValidator.evWrongOutputList,
3025 TosaErrorValidator.evDimensionMismatch,
3026 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003027 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003028 "minimum": {
3029 "op": Op.MINIMUM,
3030 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003031 "build_fcn": (
3032 build_binary_broadcast,
3033 TosaTensorGen.tgBroadcastFuzz,
3034 TosaTensorValuesGen.tvgDefault,
3035 None,
3036 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003037 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003038 "error_if_validators": (
3039 TosaErrorValidator.evRankMismatch,
3040 TosaErrorValidator.evWrongInputType,
3041 TosaErrorValidator.evWrongOutputType,
3042 TosaErrorValidator.evWrongInputList,
3043 TosaErrorValidator.evWrongOutputList,
3044 TosaErrorValidator.evDimensionMismatch,
3045 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003046 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003047 "mul": {
3048 "op": Op.MUL,
3049 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003050 "build_fcn": (
3051 build_mul,
3052 TosaTensorGen.tgBroadcastFuzz,
3053 TosaTensorValuesGen.tvgMul,
3054 TosaArgGen.agMul,
3055 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003056 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003057 "error_if_validators": (
3058 TosaErrorValidator.evWrongInputType,
3059 TosaErrorValidator.evWrongOutputType,
3060 TosaErrorValidator.evWrongInputList,
3061 TosaErrorValidator.evWrongOutputList,
3062 TosaErrorValidator.evRankMismatch,
3063 TosaErrorValidator.evDimensionMismatch,
3064 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003065 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003066 "pow": {
3067 "op": Op.POW,
3068 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003069 "build_fcn": (
3070 build_binary_broadcast,
3071 TosaTensorGen.tgBroadcastFuzz,
3072 TosaTensorValuesGen.tvgDefault,
3073 None,
3074 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003075 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003076 "error_if_validators": (
3077 TosaErrorValidator.evRankMismatch,
3078 TosaErrorValidator.evWrongInputType,
3079 TosaErrorValidator.evWrongOutputType,
3080 TosaErrorValidator.evWrongInputList,
3081 TosaErrorValidator.evWrongOutputList,
3082 TosaErrorValidator.evDimensionMismatch,
3083 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003084 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003085 "sub": {
3086 "op": Op.SUB,
3087 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003088 "build_fcn": (
3089 build_binary_broadcast,
3090 TosaTensorGen.tgBroadcastFuzz,
3091 TosaTensorValuesGen.tvgAddSub,
3092 None,
3093 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003094 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003095 "error_if_validators": (
3096 TosaErrorValidator.evRankMismatch,
3097 TosaErrorValidator.evWrongInputType,
3098 TosaErrorValidator.evWrongOutputType,
3099 TosaErrorValidator.evWrongInputList,
3100 TosaErrorValidator.evWrongOutputList,
3101 TosaErrorValidator.evDimensionMismatch,
3102 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003103 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003104 "table": {
3105 "op": Op.TABLE,
3106 # Use the automatic generation functions to create the input array
3107 # but create the table tensor in the build function, as it may be
3108 # a different type from the input
3109 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003110 "build_fcn": (
3111 build_table,
3112 TosaTensorGen.tgBasic,
3113 TosaTensorValuesGen.tvgDefault,
3114 TosaArgGen.agTable,
3115 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003116 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003117 "error_if_validators": (
3118 TosaErrorValidator.evWrongInputType,
3119 TosaErrorValidator.evWrongOutputType,
3120 TosaErrorValidator.evWrongInputList,
3121 TosaErrorValidator.evWrongOutputList,
3122 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003123 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003124 # Elementwise Unary operators
3125 "abs": {
3126 "op": Op.ABS,
3127 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003128 "build_fcn": (
3129 build_unary,
3130 TosaTensorGen.tgBasic,
3131 TosaTensorValuesGen.tvgDefault,
3132 None,
3133 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003134 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003135 "error_if_validators": (
3136 TosaErrorValidator.evWrongInputType,
3137 TosaErrorValidator.evWrongOutputType,
3138 TosaErrorValidator.evWrongInputList,
3139 TosaErrorValidator.evWrongOutputList,
3140 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003141 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003142 "bitwise_not": {
3143 "op": Op.BITWISE_NOT,
3144 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003145 "build_fcn": (
3146 build_unary,
3147 TosaTensorGen.tgBasic,
3148 TosaTensorValuesGen.tvgDefault,
3149 None,
3150 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003152 "error_if_validators": (
3153 TosaErrorValidator.evWrongInputType,
3154 TosaErrorValidator.evWrongOutputType,
3155 TosaErrorValidator.evWrongInputList,
3156 TosaErrorValidator.evWrongOutputList,
3157 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 "ceil": {
3160 "op": Op.CEIL,
3161 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 "build_fcn": (
3163 build_unary,
3164 TosaTensorGen.tgBasic,
3165 TosaTensorValuesGen.tvgDefault,
3166 None,
3167 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003169 "error_if_validators": (
3170 TosaErrorValidator.evWrongInputType,
3171 TosaErrorValidator.evWrongOutputType,
3172 TosaErrorValidator.evWrongInputList,
3173 TosaErrorValidator.evWrongOutputList,
3174 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003175 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 "clz": {
3177 "op": Op.CLZ,
3178 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003179 "build_fcn": (
3180 build_unary,
3181 TosaTensorGen.tgBasic,
3182 TosaTensorValuesGen.tvgDefault,
3183 None,
3184 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003186 "error_if_validators": (
3187 TosaErrorValidator.evWrongInputType,
3188 TosaErrorValidator.evWrongOutputType,
3189 TosaErrorValidator.evWrongInputList,
3190 TosaErrorValidator.evWrongOutputList,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 "exp": {
3194 "op": Op.EXP,
3195 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003196 "build_fcn": (
3197 build_unary,
3198 TosaTensorGen.tgBasic,
3199 TosaTensorValuesGen.tvgDefault,
3200 None,
3201 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003202 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003203 "error_if_validators": (
3204 TosaErrorValidator.evWrongInputType,
3205 TosaErrorValidator.evWrongOutputType,
3206 TosaErrorValidator.evWrongInputList,
3207 TosaErrorValidator.evWrongOutputList,
3208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003210 "floor": {
3211 "op": Op.FLOOR,
3212 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003213 "build_fcn": (
3214 build_unary,
3215 TosaTensorGen.tgBasic,
3216 TosaTensorValuesGen.tvgDefault,
3217 None,
3218 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003219 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003220 "error_if_validators": (
3221 TosaErrorValidator.evWrongInputType,
3222 TosaErrorValidator.evWrongOutputType,
3223 TosaErrorValidator.evWrongInputList,
3224 TosaErrorValidator.evWrongOutputList,
3225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003226 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003227 "log": {
3228 "op": Op.LOG,
3229 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003230 "build_fcn": (
3231 build_unary,
3232 TosaTensorGen.tgBasic,
3233 TosaTensorValuesGen.tvgDefault,
3234 None,
3235 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003236 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003237 "error_if_validators": (
3238 TosaErrorValidator.evWrongInputType,
3239 TosaErrorValidator.evWrongOutputType,
3240 TosaErrorValidator.evWrongInputList,
3241 TosaErrorValidator.evWrongOutputList,
3242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003243 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003244 "logical_not": {
3245 "op": Op.LOGICAL_NOT,
3246 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 "build_fcn": (
3248 build_unary,
3249 TosaTensorGen.tgBasic,
3250 TosaTensorValuesGen.tvgDefault,
3251 None,
3252 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003254 "error_if_validators": (
3255 TosaErrorValidator.evWrongInputType,
3256 TosaErrorValidator.evWrongOutputType,
3257 TosaErrorValidator.evWrongInputList,
3258 TosaErrorValidator.evWrongOutputList,
3259 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003260 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003261 "negate": {
3262 "op": Op.NEGATE,
3263 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 "build_fcn": (
3265 build_unary,
3266 TosaTensorGen.tgBasic,
3267 TosaTensorValuesGen.tvgNegate,
3268 None,
3269 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "qgen": TosaQuantGen.qgUnary,
3271 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003272 "error_if_validators": (
3273 TosaErrorValidator.evInputZeroPointNotZero,
3274 TosaErrorValidator.evOutputZeroPointNotZero,
3275 TosaErrorValidator.evWrongInputType,
3276 TosaErrorValidator.evWrongOutputType,
3277 TosaErrorValidator.evWrongInputList,
3278 TosaErrorValidator.evWrongOutputList,
3279 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 "reciprocal": {
3282 "op": Op.RECIPROCAL,
3283 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 "build_fcn": (
3285 build_unary,
3286 TosaTensorGen.tgBasic,
3287 TosaTensorValuesGen.tvgDefault,
3288 None,
3289 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003291 "error_if_validators": (
3292 TosaErrorValidator.evWrongInputType,
3293 TosaErrorValidator.evWrongOutputType,
3294 TosaErrorValidator.evWrongInputList,
3295 TosaErrorValidator.evWrongOutputList,
3296 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003297 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003298 "rsqrt": {
3299 "op": Op.RSQRT,
3300 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003301 "build_fcn": (
3302 build_unary,
3303 TosaTensorGen.tgBasic,
3304 TosaTensorValuesGen.tvgDefault,
3305 None,
3306 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003307 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003308 "error_if_validators": (
3309 TosaErrorValidator.evWrongInputType,
3310 TosaErrorValidator.evWrongOutputType,
3311 TosaErrorValidator.evWrongInputList,
3312 TosaErrorValidator.evWrongOutputList,
3313 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003314 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003315 # Elementwise Ternary operators
3316 "select": {
3317 "op": Op.SELECT,
3318 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003319 "build_fcn": (
3320 build_select,
3321 TosaTensorGen.tgBroadcastFuzz,
3322 TosaTensorValuesGen.tvgSelect,
3323 None,
3324 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003325 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003326 "error_if_validators": (
3327 TosaErrorValidator.evRankMismatch,
3328 TosaErrorValidator.evWrongInputType,
3329 TosaErrorValidator.evWrongOutputType,
3330 TosaErrorValidator.evWrongInputList,
3331 TosaErrorValidator.evWrongOutputList,
3332 TosaErrorValidator.evDimensionMismatch,
3333 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003334 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003335 # Comparison operators
3336 "equal": {
3337 "op": Op.EQUAL,
3338 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003339 "build_fcn": (
3340 build_comparison,
3341 TosaTensorGen.tgBroadcastFuzz,
3342 TosaTensorValuesGen.tvgEqual,
3343 None,
3344 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003345 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003346 "error_if_validators": (
3347 TosaErrorValidator.evRankMismatch,
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongInputList,
3351 TosaErrorValidator.evWrongOutputList,
3352 TosaErrorValidator.evDimensionMismatch,
3353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "greater_equal": {
3356 "op": Op.GREATER_EQUAL,
3357 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003358 "build_fcn": (
3359 build_comparison,
3360 TosaTensorGen.tgBroadcastFuzz,
3361 TosaTensorValuesGen.tvgDefault,
3362 None,
3363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003365 "error_if_validators": (
3366 TosaErrorValidator.evRankMismatch,
3367 TosaErrorValidator.evWrongInputType,
3368 TosaErrorValidator.evWrongOutputType,
3369 TosaErrorValidator.evWrongInputList,
3370 TosaErrorValidator.evWrongOutputList,
3371 TosaErrorValidator.evDimensionMismatch,
3372 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 "greater": {
3375 "op": Op.GREATER,
3376 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003377 "build_fcn": (
3378 build_comparison,
3379 TosaTensorGen.tgBroadcastFuzz,
3380 TosaTensorValuesGen.tvgDefault,
3381 None,
3382 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003384 "error_if_validators": (
3385 TosaErrorValidator.evRankMismatch,
3386 TosaErrorValidator.evWrongInputType,
3387 TosaErrorValidator.evWrongOutputType,
3388 TosaErrorValidator.evWrongInputList,
3389 TosaErrorValidator.evWrongOutputList,
3390 TosaErrorValidator.evDimensionMismatch,
3391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003393 # Reduction operators
3394 "reduce_all": {
3395 "op": Op.REDUCE_ALL,
3396 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003397 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003398 "build_fcn": (
3399 build_reduce,
3400 TosaTensorGen.tgBasic,
3401 TosaTensorValuesGen.tvgDefault,
3402 TosaArgGen.agAxis,
3403 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003404 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003405 "error_if_validators": (
3406 TosaErrorValidator.evAxisLargerRank,
3407 TosaErrorValidator.evAxisSmallerZero,
3408 TosaErrorValidator.evShapeOfAxisNotOne,
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongRank,
3412 TosaErrorValidator.evWrongInputList,
3413 TosaErrorValidator.evWrongOutputList,
3414 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003415 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003416 "reduce_any": {
3417 "op": Op.REDUCE_ANY,
3418 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003419 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003420 "build_fcn": (
3421 build_reduce,
3422 TosaTensorGen.tgBasic,
3423 TosaTensorValuesGen.tvgDefault,
3424 TosaArgGen.agAxis,
3425 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003427 "error_if_validators": (
3428 TosaErrorValidator.evAxisLargerRank,
3429 TosaErrorValidator.evAxisSmallerZero,
3430 TosaErrorValidator.evShapeOfAxisNotOne,
3431 TosaErrorValidator.evWrongInputType,
3432 TosaErrorValidator.evWrongOutputType,
3433 TosaErrorValidator.evWrongRank,
3434 TosaErrorValidator.evWrongInputList,
3435 TosaErrorValidator.evWrongOutputList,
3436 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003437 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003438 "reduce_max": {
3439 "op": Op.REDUCE_MAX,
3440 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003441 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003442 "build_fcn": (
3443 build_reduce,
3444 TosaTensorGen.tgBasic,
3445 TosaTensorValuesGen.tvgDefault,
3446 TosaArgGen.agAxis,
3447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003448 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003449 "error_if_validators": (
3450 TosaErrorValidator.evAxisLargerRank,
3451 TosaErrorValidator.evAxisSmallerZero,
3452 TosaErrorValidator.evShapeOfAxisNotOne,
3453 TosaErrorValidator.evWrongInputType,
3454 TosaErrorValidator.evWrongOutputType,
3455 TosaErrorValidator.evWrongRank,
3456 TosaErrorValidator.evWrongInputList,
3457 TosaErrorValidator.evWrongOutputList,
3458 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003459 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003460 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003461 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003463 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003464 "build_fcn": (
3465 build_reduce,
3466 TosaTensorGen.tgBasic,
3467 TosaTensorValuesGen.tvgDefault,
3468 TosaArgGen.agAxis,
3469 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003470 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003471 "error_if_validators": (
3472 TosaErrorValidator.evAxisLargerRank,
3473 TosaErrorValidator.evAxisSmallerZero,
3474 TosaErrorValidator.evShapeOfAxisNotOne,
3475 TosaErrorValidator.evWrongInputType,
3476 TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongRank,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003481 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003482 "reduce_product": {
3483 "op": Op.REDUCE_PRODUCT,
3484 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003485 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003486 "build_fcn": (
3487 build_reduce,
3488 TosaTensorGen.tgBasic,
3489 TosaTensorValuesGen.tvgDefault,
3490 TosaArgGen.agAxis,
3491 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003492 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003493 "error_if_validators": (
3494 TosaErrorValidator.evAxisLargerRank,
3495 TosaErrorValidator.evAxisSmallerZero,
3496 TosaErrorValidator.evShapeOfAxisNotOne,
3497 TosaErrorValidator.evWrongInputType,
3498 TosaErrorValidator.evWrongOutputType,
3499 TosaErrorValidator.evWrongRank,
3500 TosaErrorValidator.evWrongInputList,
3501 TosaErrorValidator.evWrongOutputList,
3502 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 "reduce_sum": {
3505 "op": Op.REDUCE_SUM,
3506 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003507 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003508 "build_fcn": (
3509 build_reduce,
3510 TosaTensorGen.tgBasic,
3511 TosaTensorValuesGen.tvgReduceSum,
3512 TosaArgGen.agAxis,
3513 ),
James Ward24dbc422022-10-19 12:20:31 +01003514 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003515 "error_if_validators": (
3516 TosaErrorValidator.evAxisLargerRank,
3517 TosaErrorValidator.evAxisSmallerZero,
3518 TosaErrorValidator.evShapeOfAxisNotOne,
3519 TosaErrorValidator.evWrongInputType,
3520 TosaErrorValidator.evWrongOutputType,
3521 TosaErrorValidator.evWrongRank,
3522 TosaErrorValidator.evWrongInputList,
3523 TosaErrorValidator.evWrongOutputList,
3524 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003526 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003527 "concat": {
3528 "op": Op.CONCAT,
3529 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 "build_fcn": (
3531 build_concat,
3532 TosaTensorGen.tgConcat,
3533 TosaTensorValuesGen.tvgConcat,
3534 TosaArgGen.agAxis,
3535 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003536 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003537 "error_if_validators": (
3538 TosaErrorValidator.evAxisLargerRank,
3539 TosaErrorValidator.evAxisSmallerZero,
3540 TosaErrorValidator.evConcatInputRankMismatch,
3541 TosaErrorValidator.evConcatShapeSumMismatch,
3542 TosaErrorValidator.evConcatInputDimMismatch,
3543 TosaErrorValidator.evWrongInputType,
3544 TosaErrorValidator.evWrongOutputType,
3545 TosaErrorValidator.evWrongOutputList,
3546 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003547 },
3548 "pad": {
3549 "op": Op.PAD,
3550 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003551 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
3553 build_pad,
3554 TosaTensorGen.tgBasic,
3555 TosaTensorValuesGen.tvgDefault,
3556 TosaArgGen.agPad,
3557 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003558 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003559 "error_if_validators": (
3560 TosaErrorValidator.evWrongInputType,
3561 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003562 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003563 TosaErrorValidator.evWrongOutputType,
3564 TosaErrorValidator.evWrongInputList,
3565 TosaErrorValidator.evWrongOutputList,
3566 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003567 },
3568 "reshape": {
3569 "op": Op.RESHAPE,
3570 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003571 "build_fcn": (
3572 build_reshape,
3573 TosaTensorGen.tgBasic,
3574 TosaTensorValuesGen.tvgDefault,
3575 TosaArgGen.agReshape,
3576 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003577 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003578 "error_if_validators": (
3579 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3580 TosaErrorValidator.evWrongInputType,
3581 TosaErrorValidator.evWrongOutputType,
3582 TosaErrorValidator.evWrongInputList,
3583 TosaErrorValidator.evWrongOutputList,
3584 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003585 },
3586 "reverse": {
3587 "op": Op.REVERSE,
3588 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003589 "build_fcn": (
3590 build_reverse,
3591 TosaTensorGen.tgBasic,
3592 TosaTensorValuesGen.tvgDefault,
3593 TosaArgGen.agAxis,
3594 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003596 "error_if_validators": (
3597 TosaErrorValidator.evAxisSmallerZero,
3598 TosaErrorValidator.evAxisLargerRank,
3599 TosaErrorValidator.evWrongInputType,
3600 TosaErrorValidator.evWrongOutputType,
3601 TosaErrorValidator.evWrongInputList,
3602 TosaErrorValidator.evWrongOutputList,
3603 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003604 },
3605 "slice": {
3606 "op": Op.SLICE,
3607 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003608 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003609 "build_fcn": (
3610 build_slice,
3611 TosaTensorGen.tgBasic,
3612 TosaTensorValuesGen.tvgDefault,
3613 TosaArgGen.agSlice,
3614 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003615 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003616 "error_if_validators": (
3617 TosaErrorValidator.evStartSmallerZero,
3618 TosaErrorValidator.evSizeSmallerEqualZero,
3619 TosaErrorValidator.evStartSizeOutsideBounds,
3620 TosaErrorValidator.evSizeOutputShapeMismatch,
3621 TosaErrorValidator.evInputSizeStartLengthMismatch,
3622 TosaErrorValidator.evWrongRank,
3623 TosaErrorValidator.evWrongInputType,
3624 TosaErrorValidator.evWrongOutputType,
3625 TosaErrorValidator.evWrongInputList,
3626 TosaErrorValidator.evWrongOutputList,
3627 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003628 },
3629 "tile": {
3630 "op": Op.TILE,
3631 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 "build_fcn": (
3633 build_tile,
3634 TosaTensorGen.tgBasic,
3635 TosaTensorValuesGen.tvgDefault,
3636 TosaArgGen.agTile,
3637 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evWrongInputType,
3641 TosaErrorValidator.evWrongOutputType,
3642 TosaErrorValidator.evWrongInputList,
3643 TosaErrorValidator.evWrongOutputList,
3644 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003645 },
3646 "transpose": {
3647 "op": Op.TRANSPOSE,
3648 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003649 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003650 "build_fcn": (
3651 build_transpose,
3652 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003653 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003654 TosaArgGen.agTranspose,
3655 ),
3656 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "error_if_validators": (
3658 TosaErrorValidator.evIndexOutsideBounds,
3659 TosaErrorValidator.evIndexUsedTwice,
3660 TosaErrorValidator.evWrongInputType,
3661 TosaErrorValidator.evWrongOutputType,
3662 TosaErrorValidator.evWrongInputList,
3663 TosaErrorValidator.evWrongOutputList,
3664 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003665 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003666 # Data nodes
3667 "const": {
3668 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003669 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003670 "build_fcn": (
3671 build_const,
3672 TosaTensorGen.tgBasic,
3673 TosaTensorValuesGen.tvgDefault,
3674 None,
3675 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003676 "types": TYPE_FIB,
3677 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003678 "identity": {
3679 "op": Op.IDENTITY,
3680 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003681 "build_fcn": (
3682 build_unary,
3683 TosaTensorGen.tgBasic,
3684 TosaTensorValuesGen.tvgDefault,
3685 None,
3686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "types": TYPE_FIB,
3688 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003689 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003690 "gather": {
3691 "op": Op.GATHER,
3692 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3693 "operands": (1, 0),
3694 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003695 "build_fcn": (
3696 build_gather,
3697 TosaTensorGen.tgBasic,
3698 TosaTensorValuesGen.tvgDefault,
3699 None,
3700 ),
James Ward24dbc422022-10-19 12:20:31 +01003701 "types": (
3702 DType.INT8,
3703 DType.INT16,
3704 DType.INT32,
3705 DType.FP16,
3706 DType.BF16,
3707 DType.FP32,
3708 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003709 "error_if_validators": (
3710 TosaErrorValidator.evWrongInputType,
3711 TosaErrorValidator.evWrongOutputType,
3712 TosaErrorValidator.evWrongInputList,
3713 TosaErrorValidator.evWrongOutputList,
3714 TosaErrorValidator.evWrongRank,
3715 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003716 },
3717 "scatter": {
3718 "op": Op.SCATTER,
3719 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003720 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003721 "operands": (2, 0),
3722 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003723 "build_fcn": (
3724 build_scatter,
3725 TosaTensorGen.tgScatter,
3726 TosaTensorValuesGen.tvgDefault,
3727 None,
3728 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003729 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003730 "error_if_validators": (
3731 TosaErrorValidator.evWrongInputType,
3732 TosaErrorValidator.evWrongOutputType,
3733 TosaErrorValidator.evWrongInputList,
3734 TosaErrorValidator.evWrongOutputList,
3735 TosaErrorValidator.evWrongRank,
3736 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003737 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003738 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003739 "resize": {
3740 "op": Op.RESIZE,
3741 "operands": (1, 0),
3742 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003743 "build_fcn": (
3744 build_resize,
3745 TosaTensorGen.tgNHWC,
3746 TosaTensorValuesGen.tvgDefault,
3747 TosaArgGen.agResize,
3748 ),
James Ward24dbc422022-10-19 12:20:31 +01003749 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003750 "invalid_test_validators": (
3751 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003752 ),
3753 "error_if_validators": (
3754 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003755 TosaErrorValidator.evScaleSmallerEqualZero,
3756 TosaErrorValidator.evScaleNLargerMax,
3757 TosaErrorValidator.evScaleDLargerMax,
3758 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003760 TosaErrorValidator.evBorderSmallerMin,
3761 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 TosaErrorValidator.evWrongInputType,
3763 TosaErrorValidator.evWrongOutputType,
3764 TosaErrorValidator.evWrongRank,
3765 TosaErrorValidator.evWrongInputList,
3766 TosaErrorValidator.evWrongOutputList,
3767 TosaErrorValidator.evBatchMismatch,
3768 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003769 TosaErrorValidator.evResizeOutputShapeMismatch,
3770 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003772 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003773 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003774 "cast": {
3775 "op": Op.CAST,
3776 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003777 "build_fcn": (
3778 build_cast,
3779 TosaTensorGen.tgBasic,
3780 TosaTensorValuesGen.tvgDefault,
3781 TosaArgGen.agCast,
3782 ),
James Ward8b390432022-08-12 20:48:56 +01003783 "types": (
3784 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003785 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003786 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003787 DType.INT8,
3788 DType.INT16,
3789 DType.INT32,
3790 DType.BOOL,
3791 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003792 "error_if_validators": (
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003798 },
3799 "rescale": {
3800 "op": Op.RESCALE,
3801 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003802 "build_fcn": (
3803 build_rescale,
3804 TosaTensorGen.tgBasic,
3805 TosaTensorValuesGen.tvgDefault,
3806 TosaArgGen.agRescale,
3807 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003808 "types": [
3809 DType.UINT8,
3810 DType.INT8,
3811 DType.INT16,
3812 DType.INT32,
3813 DType.INT48,
3814 DType.UINT16,
3815 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 "error_if_validators": (
3817 TosaErrorValidator.evInputZeroPointNotZero,
3818 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003819 TosaErrorValidator.evU16InputZeroPointNotValid,
3820 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003821 TosaErrorValidator.evScaleTrue,
3822 TosaErrorValidator.evScaleNotTrue,
3823 TosaErrorValidator.evWrongInputType,
3824 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003825 TosaErrorValidator.evWrongInputList,
3826 TosaErrorValidator.evWrongOutputList,
3827 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003828 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003829 # Custom
3830 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003831 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003832 # Two varients of cond_if, one that generates one of two constant tensors (no
3833 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3834 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003835 "cond_if_const": {
3836 "op": Op.COND_IF,
3837 "operands": (0, 2),
3838 "build_fcn": (
3839 build_cond_if_const,
3840 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003842 TosaArgGen.agCondIf,
3843 ),
3844 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003845 "error_if_validators": (
3846 TosaErrorValidator.evOutputListThenGraphMismatch,
3847 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003848 TosaErrorValidator.evCondIfCondNotMatchingBool,
3849 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003850 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003851 },
3852 "cond_if_binary": {
3853 "op": Op.COND_IF,
3854 "operands": (2, 0),
3855 "build_fcn": (
3856 build_cond_if_binary,
3857 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003858 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003859 TosaArgGen.agCondIf,
3860 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003861 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 "error_if_validators": (
3863 TosaErrorValidator.evInputListThenGraphMismatch,
3864 TosaErrorValidator.evInputListElseGraphMismatch,
3865 TosaErrorValidator.evOutputListThenGraphMismatch,
3866 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003867 TosaErrorValidator.evCondIfCondNotMatchingBool,
3868 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003869 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003870 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003871 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003872 "while_loop": {
3873 "op": Op.WHILE_LOOP,
3874 "operands": (0, 1),
3875 "build_fcn": (
3876 build_while_loop,
3877 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003878 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003879 TosaArgGen.agWhileLoop,
3880 ),
3881 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 "error_if_validators": (
3883 TosaErrorValidator.evInputListOutputListMismatch,
3884 TosaErrorValidator.evInputListCondGraphMismatch,
3885 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3886 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3887 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003888 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003889 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003890 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003891 }
3892
Kevin Cheng550ccc52021-03-03 11:21:43 -08003893
Eric Kunzee5e26762020-10-13 16:11:07 -07003894class OutputShaper:
3895 # Methods in this class compute the expected output shape and datatype
3896 # for common classes of operations
3897 def __init__(self):
3898 pass
3899
3900 # These methods return arguments that can be used for
3901 # creating a new output tensor
3902 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003903 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
3904 if error_name != ErrorIf.RankMismatch:
3905 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003906 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003907
3908 shape = []
3909 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07003911 shape.append(b.shape[i])
3912 else:
3913 shape.append(a.shape[i])
3914
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003915 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 all_dtypes = [
3917 DType.INT8,
3918 DType.INT16,
3919 DType.INT32,
3920 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01003921 DType.FP16,
3922 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003923 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01003925 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3926 outputDType = rng.choice(wrong_dtypes)
3927 else:
3928 outputDType = a.dtype
3929
3930 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003931
3932 @staticmethod
3933 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003934 assert len(a.shape) == len(b.shape)
3935 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003936
3937 shape = []
3938 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08003939 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07003940 shape.append(a.shape[i])
3941
Kevin Cheng550ccc52021-03-03 11:21:43 -08003942 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003943
3944 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003945 def unaryOp(ser, rng, a, error_name=None):
3946 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003947 all_dtypes = [
3948 DType.INT8,
3949 DType.INT16,
3950 DType.INT32,
3951 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003952 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01003953 DType.FP16,
3954 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003955 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01003956 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3957 outputDType = rng.choice(wrong_dtypes)
3958 else:
3959 outputDType = a.dtype
3960
3961 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003962
3963 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003964 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003965 if error_name != ErrorIf.RankMismatch:
3966 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003967 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003968
3969 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003970 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003971 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003972 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
3973 else:
3974 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07003975
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003976 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 all_dtypes = [
3978 DType.INT8,
3979 DType.INT16,
3980 DType.INT32,
3981 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003982 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01003983 DType.FP16,
3984 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003985 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01003986 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
3987 outputDType = rng.choice(wrong_dtypes)
3988 else:
3989 outputDType = a.dtype
3990
3991 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07003992
3993 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003994 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00003995 if error_name != ErrorIf.RankMismatch:
3996 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003997 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07003998
3999 # Do broadcast
4000 shape = []
4001 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004002 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004003 shape.append(b.shape[i])
4004 else:
4005 shape.append(a.shape[i])
4006
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004007 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004008 wrong_dtypes = [
4009 DType.INT8,
4010 DType.INT16,
4011 DType.INT32,
4012 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004013 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004014 DType.FP16,
4015 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004016 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004017 outputDType = rng.choice(wrong_dtypes)
4018 else:
4019 outputDType = DType.BOOL
4020
4021 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004022
4023 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004024 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004025 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004026 if error_name not in [
4027 ErrorIf.AxisSmallerZero,
4028 ErrorIf.AxisLargerRank,
4029 ErrorIf.ShapeOfAxisNotOne,
4030 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01004031 shape[axis] = 1
4032 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4033 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004034
Matthew Haddond6ce7252021-09-29 15:35:44 +01004035 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004036 all_dtypes = [
4037 DType.INT8,
4038 DType.INT16,
4039 DType.INT32,
4040 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004041 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004042 DType.FP16,
4043 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004044 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01004045 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4046 outputDType = rng.choice(wrong_dtypes)
4047 else:
4048 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004049
Matthew Haddond6ce7252021-09-29 15:35:44 +01004050 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004051
4052 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004053 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004054 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004055
4056 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4057 del shape[axis]
4058
4059 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
4060 remove = rng.choice([True, False])
4061 if remove and len(shape) > 1:
4062 del shape[0]
4063 else:
4064 shape.append(1)
4065 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
4066 for i in range(len(shape)):
4067 shape[i] = shape[i] + rng.integers(1, 10)
4068
4069 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004070 all_dtypes = [
4071 DType.INT8,
4072 DType.INT16,
4073 DType.INT32,
4074 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004075 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004076 DType.FP16,
4077 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004078 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004079 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
4080 outputDType = rng.choice(wrong_dtypes)
4081 else:
4082 outputDType = DType.INT32
4083
4084 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004085
4086 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004087 def conv2dOp(
4088 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4089 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004090
4091 # IFM: NHWC
4092 # Filter: OHWI
4093 # OFM: NHWC
4094
Kevin Cheng550ccc52021-03-03 11:21:43 -08004095 h = (
4096 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004097 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004098 + padding[0]
4099 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004100 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004101 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004102
Kevin Cheng550ccc52021-03-03 11:21:43 -08004103 w = (
4104 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004105 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004106 + padding[2]
4107 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004108 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004109 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004110
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004111 if error_name == ErrorIf.ConvOutputShapeMismatch:
4112 choices = [1, 2, 3]
4113 change = rng.choice(choices)
4114 # increment in multiples of stride to not hit non-integer error case
4115 if change in [1, 3]:
4116 h = h + (rng.choice(choices) * strides[0])
4117 if change in [2, 3]:
4118 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004119
Eric Kunzee5e26762020-10-13 16:11:07 -07004120 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
4121
James Ward8b390432022-08-12 20:48:56 +01004122 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004123 # Pick some potentially correct output dtype if input type is incorrect
4124 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004125 else:
James Ward8b390432022-08-12 20:48:56 +01004126 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004127
4128 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004129 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004130 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004131 else:
4132 excludes = [out_dtype]
4133 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004134 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004135
Kevin Cheng550ccc52021-03-03 11:21:43 -08004136 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004137
4138 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004139 def conv3dOp(
4140 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4141 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004142
4143 # IFM: NDHWC
4144 # Filter: ODHWI
4145 # OFM: NDHWC
4146
4147 d = (
4148 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004149 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004150 + padding[0]
4151 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004152 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07004153 ) // strides[0] + 1
4154
4155 h = (
4156 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004157 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004158 + padding[2]
4159 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004160 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07004161 ) // strides[1] + 1
4162
4163 w = (
4164 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004165 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004166 + padding[4]
4167 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004168 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07004169 ) // strides[2] + 1
4170
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004171 if error_name == ErrorIf.ConvOutputShapeMismatch:
4172 choices = [1, 2, 3, 4]
4173 change = rng.choice(choices)
4174 # increment in multiples of stride to not hit non-integer error case
4175 if change in [1, 4]:
4176 d = d + (rng.choice(choices) * strides[0])
4177 if change in [2, 4]:
4178 h = h + (rng.choice(choices) * strides[1])
4179 if change in [3, 4]:
4180 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00004181
Kevin Cheng1533b852021-09-01 12:51:58 -07004182 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4183
James Ward8b390432022-08-12 20:48:56 +01004184 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004185 # Pick some potentially correct output dtype if input type is incorrect
4186 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07004187 else:
James Ward8b390432022-08-12 20:48:56 +01004188 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004189
4190 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004191 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004192 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004193 else:
4194 excludes = [out_dtype]
4195 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004196 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07004197
4198 return ser.addOutput(ofm_shape, out_dtype)
4199
4200 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01004202 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004203 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004204 # IFM: NHWC
4205 # Filter: HWCM
4206 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004207
Kevin Cheng550ccc52021-03-03 11:21:43 -08004208 h = (
4209 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004210 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004211 + padding[0]
4212 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004213 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004214 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004215
Kevin Cheng550ccc52021-03-03 11:21:43 -08004216 w = (
4217 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004218 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004219 + padding[2]
4220 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004221 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004222 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004223
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004224 if error_name == ErrorIf.ConvOutputShapeMismatch:
4225 choices = [1, 2, 3]
4226 change = rng.choice(choices)
4227 # increment in multiples of stride to not hit non-integer error case
4228 if change in [1, 3]:
4229 h = h + (rng.choice(choices) * strides[0])
4230 if change in [2, 3]:
4231 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004232
Eric Kunzee5e26762020-10-13 16:11:07 -07004233 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4234
James Ward8b390432022-08-12 20:48:56 +01004235 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004236 # Pick some potentially correct output dtype if input type is incorrect
4237 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004238 else:
James Ward8b390432022-08-12 20:48:56 +01004239 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004240
4241 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004242 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004243 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004244 else:
4245 excludes = [out_dtype]
4246 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004247 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004248
Kevin Cheng550ccc52021-03-03 11:21:43 -08004249 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004250
4251 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004252 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004253 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004254 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004255 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004256 h = 1
4257 w = 1
4258 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004259 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
4260 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004261
4262 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004263 choices = [1, 2, 3]
4264 change = rng.choice(choices)
4265 # increment in multiples of stride to not hit non-integer error case
4266 if change in [1, 3]:
4267 h = h + (rng.choice(choices) * stride[0])
4268 if change in [2, 3]:
4269 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07004270 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004271
4272 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004273 all_dtypes = [
4274 DType.INT8,
4275 DType.INT16,
4276 DType.INT32,
4277 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004278 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004279 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004280 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004281 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004282 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4283 outputDType = rng.choice(wrong_dtypes)
4284 else:
4285 outputDType = ifm.dtype
4286
4287 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004288
4289 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004290 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004291 # input: N, IC
4292 # filter: OC, IC
4293 # output: N, OC
4294
4295 output_shape = [input.shape[0], filter.shape[0]]
4296
James Ward8b390432022-08-12 20:48:56 +01004297 # Validated in arg_gen (also invalidated for ErrorIf)
4298 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004299
Kevin Cheng550ccc52021-03-03 11:21:43 -08004300 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004301
4302 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004303 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004304 # a: N, H, C
4305 # b: N, C, W
4306 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004307
Kevin Cheng2d60f002021-06-09 14:18:32 -07004308 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004309
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004310 if error_name == ErrorIf.WrongOutputType:
4311 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004312 incorrect_types = (
4313 DType.INT4,
4314 DType.INT8,
4315 DType.INT16,
4316 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004317 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004318 DType.FP16,
4319 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004320 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004321 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004322 incorrect_types = (
4323 DType.INT4,
4324 DType.INT8,
4325 DType.INT16,
4326 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004327 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004328 DType.FP16,
4329 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004330 )
James Ward24dbc422022-10-19 12:20:31 +01004331 elif (
4332 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
4333 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004334 incorrect_types = (
4335 DType.INT4,
4336 DType.INT8,
4337 DType.INT16,
4338 DType.INT32,
4339 DType.INT48,
4340 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004341 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004342 elif error_name == ErrorIf.WrongInputType:
4343 # Pick some potentially correct output dtype if input type is incorrect
4344 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004345 else:
James Ward8b390432022-08-12 20:48:56 +01004346 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07004347
Kevin Cheng550ccc52021-03-03 11:21:43 -08004348 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004349
4350 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004351 def concatOp(ser, rng, axis, *a, error_name=None):
Matthew Haddon818ab902021-07-27 09:12:49 +01004352 input1 = a[0]
4353 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004354
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004355 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01004356 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004357 if not (
4358 # unable to concat tensors of different ranks
4359 error_name == ErrorIf.ConcatInputRankMismatch
4360 # unable to concat tensors along an invalid axis
4361 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004362 ):
4363 for tensor in remaining_inputs:
4364 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07004365
Matthew Haddon01c359d2021-10-15 16:30:48 +01004366 if error_name == ErrorIf.ConcatShapeSumMismatch:
4367 output_shape[axis] += rng.integers(5, 10)
4368
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004369 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004370 all_dtypes = {
4371 DType.INT8,
4372 DType.INT16,
4373 DType.INT32,
4374 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004375 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004376 DType.FP16,
4377 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004378 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004379 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
4380 outputDType = rng.choice(wrong_dtypes)
4381 else:
4382 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01004383
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004384 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004385
4386 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004387 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004388
4389 output_shape = a.shape.copy()
4390
4391 for i in range(len(output_shape)):
4392 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4393
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004394 if error_name == ErrorIf.PadOutputShapeMismatch:
4395 bad_dim = rng.choice(range(len(output_shape)))
4396 output_shape[bad_dim] -= rng.choice([1, 2])
4397
Matthew Haddone807aae2021-10-11 18:12:58 +01004398 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004399 all_dtypes = [
4400 DType.INT8,
4401 DType.INT16,
4402 DType.INT32,
4403 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004404 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004405 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004406 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004407 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004408 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4409 outputDType = rng.choice(wrong_dtypes)
4410 else:
4411 outputDType = a.dtype
4412
4413 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004414
4415 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004416 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004417 output_shape = shape.copy()
4418
Matthew Haddone807aae2021-10-11 18:12:58 +01004419 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
4420 for i in range(len(output_shape)):
4421 output_shape[i] = output_shape[i] + rng.integers(1, 10)
4422
4423 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004424 all_dtypes = [
4425 DType.INT8,
4426 DType.INT16,
4427 DType.INT32,
4428 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004429 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004430 DType.FP16,
4431 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004432 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004433 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4434 outputDType = rng.choice(wrong_dtypes)
4435 else:
4436 outputDType = a.dtype
4437
4438 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004439
4440 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004441 def sliceOp(ser, rng, a, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004442
Matthew Haddone807aae2021-10-11 18:12:58 +01004443 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004444 all_dtypes = [
4445 DType.INT8,
4446 DType.INT16,
4447 DType.INT32,
4448 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004449 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004450 DType.FP16,
4451 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004452 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004453 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4454 outputDType = rng.choice(wrong_dtypes)
4455 else:
4456 outputDType = a.dtype
4457
4458 if error_name == ErrorIf.SizeOutputShapeMismatch:
4459 output_shape = size.copy()
4460 for index in range(len(output_shape)):
4461 if output_shape[index] <= 2:
4462 output_shape[index] = output_shape[index] + rng.choice([1, 2])
4463 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004464 output_shape[index] = output_shape[index] + rng.choice(
4465 [-2, -1, 1, 2]
4466 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004467 else:
4468 output_shape = size.copy()
4469
4470 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004471
4472 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004473 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004474
4475 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004476 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004477
4478 for i in range(len(output_shape)):
4479 output_shape[i] = a.shape[i] * multiples[i]
4480
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004481 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004482 all_dtypes = [
4483 DType.INT8,
4484 DType.INT16,
4485 DType.INT32,
4486 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004487 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004488 DType.FP16,
4489 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004490 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004491 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4492 outputDType = rng.choice(wrong_dtypes)
4493 else:
4494 outputDType = a.dtype
4495
4496 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004497
4498 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004499 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004500 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01004501
Kevin Cheng550ccc52021-03-03 11:21:43 -08004502 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004503
Matthew Haddone807aae2021-10-11 18:12:58 +01004504 if error_name == ErrorIf.IndexOutsideBounds:
4505 for i in range(len(output_shape)):
4506 output_shape[i] = a.shape[0]
4507 else:
4508 for i in range(len(output_shape)):
4509 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004510
Matthew Haddone807aae2021-10-11 18:12:58 +01004511 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004512 all_dtypes = [
4513 DType.INT8,
4514 DType.INT16,
4515 DType.INT32,
4516 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004517 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004518 DType.FP16,
4519 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004520 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004521 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4522 outputDType = rng.choice(wrong_dtypes)
4523 else:
4524 outputDType = a.dtype
4525
4526 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004527
4528 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004529 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004530 if error_name != ErrorIf.WrongRank:
4531 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004532 assert len(indices.shape) == 2
4533 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004534
Kevin Cheng77d0f762020-11-24 10:26:32 -08004535 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4536
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004537 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004538 all_dtypes = [
4539 DType.INT8,
4540 DType.INT16,
4541 DType.INT32,
4542 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004543 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004544 DType.FP16,
4545 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004546 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004547 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
4548 outputDType = rng.choice(wrong_dtypes)
4549 else:
4550 outputDType = values.dtype
4551
4552 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004553
4554 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004555 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004556 if error_name != ErrorIf.WrongRank:
4557 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004558 assert len(indices.shape) == 2
4559 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004560 assert values_in.shape[0] == indices.shape[0] # N
4561 assert input.shape[1] == indices.shape[1] # W
4562 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004563
4564 output_shape = values_in.shape
4565
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004566 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004567 all_dtypes = [
4568 DType.INT8,
4569 DType.INT16,
4570 DType.INT32,
4571 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004572 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004573 DType.FP16,
4574 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004575 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004576 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
4577 outputDType = rng.choice(wrong_dtypes)
4578 else:
4579 outputDType = values_in.dtype
4580
4581 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004582
4583 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004584 def tableOp(ser, rng, input, error_name=None):
4585 # Same shape as the input, dtype dependent on input dtype
4586 if error_name != ErrorIf.WrongInputType:
4587 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00004588 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004589 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004590 wrong_dtypes = [
4591 DType.INT8,
4592 DType.INT16,
4593 DType.INT32,
4594 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004595 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004596 DType.FP16,
4597 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004598 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004599 wrong_dtypes.remove(output_dtype)
4600 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004601 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004602
4603 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004605 serializer,
4606 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004607 input,
4608 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004609 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004610 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004611 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004612 input_dtype,
4613 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004614 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004615 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004616 # Calculate OH, OW
4617 scale_y_n = scale[0]
4618 scale_y_d = scale[1]
4619 scale_x_n = scale[2]
4620 scale_x_d = scale[3]
4621 if error_name == ErrorIf.ScaleSmallerEqualZero:
4622 scale_y_n = max(scale_y_n, 1)
4623 scale_y_d = max(scale_y_d, 1)
4624 scale_x_n = max(scale_x_n, 1)
4625 scale_x_d = max(scale_x_d, 1)
4626
4627 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
4628 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
4629
4630 if error_name is not None:
4631 # Make sure the output tensor is valid, which can occur when
4632 # scale, offset or border have been changed for ERROR_IFs
4633 oh = max(oh, 1)
4634 ow = max(ow, 1)
4635 if error_name != ErrorIf.MaxDimExceeded:
4636 oh = min(oh, MAX_RESIZE_DIMENSION - 1)
4637 ow = min(ow, MAX_RESIZE_DIMENSION - 1)
4638
4639 if error_name == ErrorIf.ResizeOutputShapeMismatch:
4640 choices = [1, 2, 3]
4641 change = rng.choice(choices)
4642 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
4643 if change in [1, 3]:
4644 if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
4645 oh -= scale_y_d
4646 assert oh > 0 # Should have been caught in agResize
4647 else:
4648 oh += scale_y_d
4649 if change in [2, 3]:
4650 if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
4651 ow -= scale_x_d
4652 assert ow > 0 # Should have been caught in agResize
4653 else:
4654 ow += scale_x_d
4655
Matthew Haddon848efb42021-09-09 12:30:53 +01004656 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004657 output_dims = [
4658 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004659 oh,
4660 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004661 input.shape[0],
4662 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004663 elif error_name == ErrorIf.BatchMismatch:
4664 output_dims = [
4665 input.shape[0] + rng.integers(1, 10),
4666 oh,
4667 ow,
4668 input.shape[3],
4669 ]
4670 elif error_name == ErrorIf.ChannelMismatch:
4671 output_dims = [
4672 input.shape[0],
4673 oh,
4674 ow,
4675 input.shape[3] + rng.integers(1, 10),
4676 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01004677 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004678 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004679
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004680 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004681
4682 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004683 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004684 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004685
4686 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004687 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004688 if error_name == ErrorIf.ConvOutputShapeMismatch:
4689 choices = [1, 2, 3]
4690 change = rng.choice(choices)
4691 if change in [1, 3]:
4692 output_shape[1] = output_shape[1] + rng.choice(choices)
4693 if change in [2, 3]:
4694 output_shape[2] = output_shape[2] + rng.choice(choices)
4695
James Ward8b390432022-08-12 20:48:56 +01004696 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004697 # Pick some potentially correct output dtype if input type is incorrect
4698 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004699 else:
James Ward8b390432022-08-12 20:48:56 +01004700 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004701
4702 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004703 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004704 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004705 else:
4706 excludes = [out_dtype]
4707 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004708 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004709
Kevin Cheng550ccc52021-03-03 11:21:43 -08004710 return ser.addOutput(output_shape, out_dtype)