blob: c29763bac36a51ff6451a73b94acbcd55cc37d42 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnson05c711e2022-12-12 18:00:41 +000017from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010018from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010020from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010026 # Maximum rank of tensor supported by test generator.
27 TOSA_TENSOR_MAX_RANK = 6
28
def __init__(self, args):
    """Initialise generator state from the parsed command-line ``args``.

    Records the output location and seed, creates the NumPy random
    generator, builds the operator lists and stores the floating-point
    range used for random tensor data.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer until createSerializer() is called for a test.
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape() returns this shape instead of a random one.
    self.targetted_shape = None
    # Bounds for random floating-point data; min/max make the result
    # independent of the order in which the range endpoints were given.
    fp_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_range)
    self.random_fp_high = max(fp_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070043
def createSerializer(self, opName, testPath):
    """Create the directory for one test and a serializer that writes into it.

    ``self.testPath`` becomes ``opName/testPath``, relative to
    ``self.basePath``. The directory is created if it does not already exist.
    """
    self.testPath = os.path.join(opName, testPath)
    test_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(test_dir, exist_ok=True)
    dump_consts = self.args.dump_consts
    self.ser = ts.TosaSerializer(test_dir, saveConstsToFile=dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070050
def getSerializer(self):
    """Return the current serializer object (``self.ser``)."""
    serializer = self.ser
    return serializer
53
def serialize(self, testName):
    """Write the current test's binary flatbuffer and JSON description.

    Writes ``<testName>.tosa`` (bytes from ``self.ser.serialize()``) and
    ``desc.json`` (text from ``self.ser.writeJson``, which is given the
    ``.tosa`` file name) into ``basePath/testPath``.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_file = f"{testName}.tosa"

    with open(os.path.join(test_dir, tosa_file), "wb") as fd:
        fd.write(self.ser.serialize())

    # Explicit encoding so the written JSON does not depend on the host locale.
    with open(os.path.join(test_dir, "desc.json"), "w", encoding="utf-8") as fd:
        fd.write(self.ser.writeJson(tosa_file))
Eric Kunzee5e26762020-10-13 16:11:07 -070062
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded NumPy generator.

    Args:
        seed: Seed for the new generator. When omitted, ``self.random_seed + 1``
            is used.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
67
def getRandTensor(self, shape, dtype):
    """Return an array of random values with the given ``shape`` for ``dtype``.

    * BOOL: random True/False values.
    * Integer types: uniform over the type's range (INT4 uses the TOSA weight
      range -7..7), stored as int32; INT48 is stored as int64.
    * FP16 / BF16 / FP32: uniform over [random_fp_low, random_fp_high). BF16
      data is generated as float32 and passed through vect_f32_to_bf16.

    Raises:
        Exception: if ``dtype`` is not one of the types above.
    """

    def fp_values():
        # Uniform floats in the configured floating-point range.
        return self.rng.uniform(
            low=self.random_fp_low, high=self.random_fp_high, size=shape
        )

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    # Integer types held in int32 only need their half-open [low, high) range.
    if dtype == DType.INT4:
        # TOSA specific INT4 weight range from -7 to 7
        low, high = -7, 8
    elif dtype == DType.INT8:
        low, high = -128, 128
    elif dtype == DType.UINT8:
        low, high = 0, 256
    elif dtype == DType.INT16:
        low, high = -32768, 32768
    elif dtype == DType.UINT16:
        low, high = 0, 65536
    elif dtype == DType.INT32:
        low, high = -(1 << 31), (1 << 31)
    elif dtype == DType.INT48:
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FP16:
        return np.float16(fp_values())
    elif dtype == DType.BF16:
        f32_tensor = np.float32(fp_values())
        # Floor the last 16 bits of each f32 value
        return np.float32(vect_f32_to_bf16(f32_tensor))
    elif dtype == DType.FP32:
        return np.float32(fp_values())
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low=low, high=high, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700112
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per (shape, dtype) pair.

    Args:
        shape_list: Shapes of the placeholders to create.
        dtype_list: DType for each entry of ``shape_list``; must be the same
            length.

    Returns:
        The values returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    # Pair shapes with dtypes directly instead of indexing dtype_list.
    for shape, dtype in zip(shape_list, dtype_list):
        arr = self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, arr))

    return placeholders
123
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant per (shape, dtype) pair.

    Args:
        shape_list: Shapes of the constants to create.
        dtype_list: DType for each entry of ``shape_list``; must be the same
            length.

    Returns:
        The values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    # Pair shapes with dtypes directly instead of indexing dtype_list.
    for shape, dtype in zip(shape_list, dtype_list):
        arr = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, arr))

    return consts
134
def makeShape(self, rank):
    """Return an int32 shape array.

    If a truthy target shape has been set via setTargetShape, that shape is
    returned and ``rank`` is ignored. An empty or None target does not count.
    Otherwise ``rank`` dimensions are drawn from
    [tensor_shape_range[0], tensor_shape_range[1]).
    """
    target = self.targetted_shape
    if not target:
        shape_range = self.args.tensor_shape_range
        target = self.rng.integers(
            low=shape_range[0], high=shape_range[1], size=rank
        )
    return np.int32(target)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def setTargetShape(self, shape):
    """Make makeShape() return ``shape``; a falsy value restores random shapes."""
    self.targetted_shape = shape
    return None
148
def randInt(self, low=0, high=256):
    """Return a single random integer in [low, high) as an ``np.int32`` scalar."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    as_int32 = np.int32(drawn)
    return as_int32[0]
151
def getRandNumberDType(self, dtype):
    """Return one random scalar value for ``dtype``.

    Floating-point types are drawn uniformly from
    [random_fp_low, random_fp_high); BF16 values are passed through
    vect_f32_to_bf16. Integer types are drawn uniformly over their range
    (INT4 uses the TOSA weight range -7..7; UINT8/UINT16 are unsigned) and
    returned as ``np.int32`` (``np.int64`` for INT48). BOOL returns a random
    True/False.

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    if dtype == DType.FP32:
        return np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.FP16:
        return np.float16(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.BF16:
        rand_f32 = np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
        return vect_f32_to_bf16(rand_f32)
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    elif dtype == DType.INT4:
        # TOSA specific INT4 weight range from -7 to 7
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        # Unsigned types use the same ranges as getRandTensor.
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: INT48 values do not fit in int32
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
185
def shapeStr(self, shape):
    """Return ``shape`` as its dimensions joined by "x", e.g. [2, 3, 4] -> "2x3x4".

    An empty shape gives the empty string.
    """
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
def typeStr(self, dtype):
    """Return the short string name of ``dtype`` from DTYPE_ATTRIBUTES.

    For a list/tuple of at least two dtypes, every entry is converted but only
    the first two names are returned, joined by "x"; a third entry (the
    accumulator type) is left out of the string.

    Raises:
        Exception: if a dtype has no entry in DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
def typeWidth(self, dtype):
    """Return the data width recorded for ``dtype`` in DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
# Operator build functions
# Argument generators return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
# used to generate the test name and the build_fcn_arg_list is expanded and
# passed to one of the build_* functions below.
221
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Build a unary operator test and return its result tensor.

    Args:
        op: Operator description dict (its "op" and "operands" entries are
            used), or an int op code (per the original note, as used for
            placeholders), which is added directly without validation.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions passed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: [input_zp, output_zp] zero points, used for NEGATE. When None,
            NEGATE uses [0, 0].

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder returns an int, ABS/other ops does not
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # qinfo defaults to None; fall back to zero points of 0 (as
        # build_pool2d does) instead of failing on qinfo[0].
        if qinfo is None:
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
272
def build_binary_broadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a broadcasting binary operator test.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        a: First input tensor.
        b: Second input tensor.
        validator_fcns: ERROR_IF validator functions passed to
            TosaErrorValidator.evValidateErrorIfs. Defaults to None like the
            other build_* methods.
        error_name: ErrorIf being tested, or None for a valid test.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
305
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a binary operator whose inputs are not broadcast.

    No ERROR_IF validation is performed here; ``validator_fcns`` and
    ``error_name`` are accepted but not used.

    Returns:
        The result tensor.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operands = [a.name, b.name]
    self.ser.addOperator(op["op"], operands, [out_tensor.name])
    return out_tensor
310
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Build an ARITHMETIC_RIGHT_SHIFT test.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        a: Value tensor.
        b: Shift tensor.
        round: Rounding flag stored in the ArithmeticRightShiftAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf being tested, or None for a valid test.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Operand name lists; for ERROR_IF tests they may be invalidated.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
348
# Build a MUL operator test. Returns the result tensor, or None when
# TosaErrorValidator.evValidateErrorIfs rejects the combination.
# Args: op (dict; "op" and "operands" entries are used), a/b (input tensors),
# shift (stored in the MulAttribute), validator_fcns, error_name.
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100349 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000350 result_tens = OutputShaper.binaryBroadcastOp(
351 self.ser, self.rng, a, b, error_name
352 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700353
354 # Special for multiply:
355 # Force the result to INT32 for INT types
# NOTE(review): indentation is not preserved in this listing, so it is not
# visible here whether the WrongOutputType check below is nested under the
# integer-type check or sits beside it. This determines whether float inputs
# get an INT8/INT16/INT48 output (and consume an RNG draw); confirm before
# changing.
James Ward24dbc422022-10-19 12:20:31 +0100356 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700357 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100358 if error_name == ErrorIf.WrongOutputType:
# Pick a random output dtype that differs from the INT32 result set above.
359 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
360 outputDType = self.rng.choice(all_dtypes)
361 result_tens.setDtype(outputDType)
362
363 # Invalidate Input/Output list for error if checks.
364 input_list = [a.name, b.name]
365 output_list = [result_tens.name]
366 pCount, cCount = op["operands"]
367 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000368 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
369 self, error_name, input_list, output_list
370 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100371
Les Bell729b0352021-11-24 10:28:21 +0000372 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100373 self.ser,
374 validator_fcns,
375 error_name,
376 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000377 input1=a,
378 input2=b,
379 input_dtype=a.dtype,
380 output_dtype=result_tens.dtype,
381 result_tensor=result_tens,
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100382 input_list=input_list,
383 output_list=output_list,
384 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000385 ):
386 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700387
Kevin Chengaee1fac2020-11-11 13:54:06 -0800388 attr = ts.TosaSerializerAttribute()
389 attr.MulAttribute(shift)
390
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000391 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700392 return result_tens
393
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Build a TABLE operator test.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        a: Input tensor.
        table: Table values stored in the TableAttribute.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf being tested, or None for a valid test.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    # Operand name lists; for ERROR_IF tests they may be invalidated.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
427
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Build a SELECT operator test.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        cond: Condition tensor.
        a: First value tensor.
        b: Second value tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf being tested, or None for a valid test.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    # Operand name lists; for ERROR_IF tests they may be invalidated.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
464
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Build a comparison operator test.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        a: First input tensor.
        b: Second input tensor.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf being tested, or None for a valid test.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    # Operand name lists; for ERROR_IF tests they may be invalidated.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    return result_tens
503
def build_argmax(self, op, a, axis, validator_fcns=None, error_name=None):
    """Build an ARGMAX test over ``axis`` of ``a``.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        a: Input tensor.
        axis: Axis value passed to the output shaper and stored in the
            AxisAttribute.
        validator_fcns: ERROR_IF validator functions. Defaults to None like the
            other build_* methods.
        error_name: ErrorIf being tested, or None (default) for a valid test.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
538
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator test.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        input: Input tensor.
        accum_dtype: Accumulator DType stored in the PoolAttribute.
        stride: Stride values.
        pad: Padding values.
        kernel: Kernel size.
        validator_fcns: ERROR_IF validator functions.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: [input_zp, output_zp] zero points; the attribute uses 0 for
            both when this is None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    result_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # For a wrong (non INT8/UINT8) input type, derive the zero points from the
    # actual input and output dtypes.
    wrong_input_needs_zp = (
        error_name == ErrorIf.WrongInputType
        and input.dtype not in [DType.INT8, DType.UINT8]
    )
    if wrong_input_needs_zp:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Operand name lists; for ERROR_IF tests they may be invalidated.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    checks = dict(
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    input_zp, output_zp = (0, 0) if qinfo is None else (qinfo[0], qinfo[1])

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, input_zp, output_zp, accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
600
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MAX_POOL2D test by delegating to build_pool2d.

    Max pooling has no accumulator type, so DType.UNKNOWN is passed as
    ``accum_dtype``; all other arguments are forwarded unchanged.
    """
    no_accumulator = DType.UNKNOWN
    return self.build_pool2d(
        op,
        input,
        no_accumulator,
        stride,
        pad,
        kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
625
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D test and return its result tensor.

    Args:
        op: Operator description dict; its "op" and "operands" entries are used.
        ifm: Input feature-map tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator DType stored in the ConvAttribute.
        strides: Stride values.
        padding: Padding values; exactly 4 entries are required.
        dilations: Dilation values.
        validator_fcns: ERROR_IF validator functions passed to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ErrorIf being tested, or None for a valid test.
        qinfo: Two zero points passed as qinfo[0]/qinfo[1] to ConvAttribute
            (NOTE(review): presumably input and weight zero points; confirm).
            The attribute uses [0, 0] when qinfo is None.

    Returns:
        The result tensor, or None when the ERROR_IF validation fails.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a wrong (non INT8/UINT8) input type, recompute the zero points from
    # the actual input and output dtypes.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; use zero points of 0 (as build_pool2d does)
    # rather than failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
697
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator and return its output tensor.

    Args:
        op: Operator definition; ``op["op"]`` is the serialized opcode and the
            entries of ``op["operands"]`` are summed to give the operand count.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper and to
            ``ConvAttribute``.
        strides, padding, dilations: Convolution parameters; ``padding`` must
            have exactly six entries.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Two zero points; ``qinfo[0]``/``qinfo[1]`` are written to
            ``ConvAttribute`` (presumably input/weight zero points - confirm).

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType test whose input is not INT8/UINT8, derive fresh
    # zero points from the actual input and output dtypes.
    regenerate_zero_points = error_name == ErrorIf.WrongInputType and (
        ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if regenerate_zero_points:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not passed:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
769
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator and return its output tensor.

    Args:
        op: Operator definition; ``op["op"]`` is the serialized opcode and the
            entries of ``op["operands"]`` are summed to give the operand count.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper and to
            ``TransposeConvAttribute``.
        stride: Stride values.
        out_pad: Output padding; must have exactly four entries.
        output_shape: Requested output shape passed to the output shaper and
            stored in the attribute.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Two zero points written to the attribute as ``qinfo[0]`` and
            ``qinfo[1]``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # For a WrongInputType test whose input is not INT8/UINT8, derive fresh
    # zero points from the actual input and output dtypes.
    regenerate_zero_points = error_name == ErrorIf.WrongInputType and (
        ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if regenerate_zero_points:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        pad=out_pad,
        stride=stride,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not passed:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, tconv_attr)
    return out_tensor
834
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator and return its output tensor.

    Args:
        op: Operator definition; ``op["op"]`` is the serialized opcode and the
            entries of ``op["operands"]`` are summed to give the operand count.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype forwarded to the output shaper and to
            ``ConvAttribute``.
        strides, padding, dilations: Convolution parameters.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Two zero points written to ``ConvAttribute`` as ``qinfo[0]``
            and ``qinfo[1]``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType test whose input is not INT8/UINT8, derive fresh
    # zero points from the actual input and output dtypes.
    regenerate_zero_points = error_name == ErrorIf.WrongInputType and (
        ifm.dtype not in (DType.INT8, DType.UINT8)
    )
    if regenerate_zero_points:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
        ]

    operand_count = sum(op["operands"])
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_list=inputs,
        output_list=outputs,
        num_operands=operand_count,
    )
    if not passed:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype
    )
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
905
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator and return its output tensor.

    Args:
        op: Operator definition; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into two counts that are added.
        ifm, filter, bias: Input, weight and bias tensors.
        accum_dtype: Accumulator dtype used by the output shaper, the
            validators and ``FullyConnectedAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Two zero points written to the attribute as ``qinfo[0]`` and
            ``qinfo[1]``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        qinfo=qinfo,
        accum_dtype=accum_dtype,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, fc_attr)
    return out_tensor
954
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL operator of ``a`` and ``b``; return its output tensor.

    Args:
        op: Operator definition; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into two counts that are added.
        a, b: The two input tensors.
        accum_dtype: Accumulator dtype used by the output shaper, the
            validators and ``MatMulAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF case being generated, or None.
        qinfo: Two zero points written to the attribute as ``qinfo[0]`` and
            ``qinfo[1]``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        qinfo=qinfo,
        accum_dtype=accum_dtype,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], inputs, outputs, matmul_attr)
    return out_tensor
996
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator over ``axis`` of ``a``.

    ``validator_fcns`` now defaults to None, matching every other ``build_*``
    method. Existing callers that pass it explicitly are unaffected.

    Args:
        op: Operator definition; ``op["op"]`` is the opcode and
            ``op["operands"]`` unpacks into two counts that are added.
        a: Input tensor.
        axis: Axis to reduce; also stored in the serialized ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: The ERROR_IF case being generated, or None.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1031
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a CLAMP operator on ``a`` with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller one is
    ``min_val`` and the larger one is ``max_val``. For an
    ``ErrorIf.MaxSmallerMin`` test, the pair is redrawn until the values
    differ, and then the roles are swapped so that ``max_val < min_val``.

    For BF16/FP16/FP32 inputs, the bounds go into the last two value
    arguments of ``ClampAttribute`` and the first two are 0. For other dtypes
    it is the other way round.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        # Draw order matters for reproducibility: first element first.
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    bounds = draw_pair()
    inverted = error_name == ErrorIf.MaxSmallerMin
    if inverted:
        # Equal values cannot produce max < min, so keep redrawing.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()

    low, high = min(bounds), max(bounds)
    if inverted:
        min_val, max_val = high, low
    else:
        min_val, max_val = low, high

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        min_val=min_val,
        max_val=max_val,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype == DType.FP16:
        # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
        min_val = min_val.astype(np.float32)
        max_val = max_val.astype(np.float32)

    if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
    else:
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)

    self.ser.addOperator(op["op"], inputs, outputs, clamp_attr)
    return out_tensor
1087
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator on ``a``.

    The attribute value is a random FP32 number from ``getRandNumberDType``.
    ``validator_fcns`` is accepted for signature consistency but is unused.
    No ERROR_IF validation is run here.

    Returns:
        The output tensor from ``OutputShaper.unaryOp``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    relu_attr = ts.TosaSerializerAttribute()

    attr_value = self.getRandNumberDType(DType.FP32)
    relu_attr.LeakyReluAttribute(attr_value)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], relu_attr)
    return out_tensor
1096
# NOTE: Needs an additional type/input.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU operator with ``a`` as its only input.

    No attribute is attached and no ERROR_IF validation is run.
    ``validator_fcns`` is accepted for signature consistency but is unused.

    Returns:
        The output tensor from ``OutputShaper.unaryOp``.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out_tensor.name])
    return out_tensor
1103
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a SIGMOID operator on ``a`` (no attribute).

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1134
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a TANH operator on ``a`` (no attribute).

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1165
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a CONCAT operator.

    ``a`` holds the input tensors followed by the concatenation axis as its
    last element. The axis is packed in with the tensors so that a
    variable-length tensor list can be passed. Unless a WrongInputType test
    is being generated, that last element must be a plain ``int``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    axis, tensors = a[-1], a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        inputs=tensors,
        input_shape=tensors[0].shape,
        input_dtype=tensors[0].dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001214
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator on ``a``.

    The ``PadAttribute`` receives the flattened ``padding`` plus both pad
    constants. The attribute is built before validation, exactly as in the
    original ordering, because ``PadAttribute`` is given the serializer's
    builder.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (the operator is not added in that case).
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(
        self.ser.builder, flat_padding, pad_const_int, pad_const_float
    )

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        pad=padding,
        qinfo=qinfo,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001262
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of ``a`` to ``newShape``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return out_tensor
1298
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REVERSE of ``a`` along ``axis`` (stored as an AxisAttribute).

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
1333
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of ``a`` using permutation ``perms``.

    The ``TransposeAttribute`` is built before validation, preserving the
    original statement order.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (the operator is not added in that case).
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        perms=perms,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, perm_attr)
    return out_tensor
1368
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of ``a`` beginning at ``start`` with extent ``size``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        start=start,
        size=size,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return out_tensor
1406
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE of ``a`` with the given ``multiples``.

    Returns:
        The output tensor, or None if ``evValidateErrorIfs`` returns a falsy
        value (nothing is serialized in that case).
    """
    out_tensor = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    n_placeholders, n_consts = op["operands"]
    # ERROR_IF generation may hand back invalidated input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensor=out_tensor,
        input_list=inputs,
        output_list=outputs,
        num_operands=n_placeholders + n_consts,
    )
    if not passed:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], inputs, outputs, tile_attr)
    return out_tensor
1440
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER operator reading from ``values``.

    An INT32 indices constant of shape [values.shape[0], W] is generated
    here, with every entry drawn from [0, K) where K = values.shape[1], so
    no index exceeds the values tensor. W is drawn via ``self.randInt``
    from ``self.args.tensor_shape_range``.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    k_dim = values.shape[1]
    w_dim = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=result_tens.shape,
        input_dtype=values.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
1487
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER operator writing ``input`` into ``values_in``.

    An INT32 indices constant of shape [values_in.shape[0], input.shape[1]]
    is generated here, with every entry drawn from [0, values_in.shape[1]),
    so no index exceeds the values_in tensor.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    k_dim = values_in.shape[1]
    w_dim = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_dim, size=[values_in.shape[0], w_dim])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    result_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [result_tens.name],
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tens.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
1532
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE operator on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are given both to the
    output shaper and to the RESIZE attribute.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        offset=offset,
        border=border,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=result_tens.shape,
        input_list=in_names,
        output_list=out_names,
        result_tensor=result_tens,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return result_tens
1594
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator passing ``val`` and ``val2`` through.

    Two result tensors are created with OutputShaper.unaryOp, but only the
    first is returned. ``validator_fcns`` is unused: no ERROR_IF
    validation is performed here.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator-list entry (a dict); the serializer needs its
    # Op value, exactly as every other build_* method passes op["op"].
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1602
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark the already-created tensor ``val`` as a graph output and return it.

    No operator is added; ``op``, ``validator_fcns`` and ``error_name`` are
    accepted only for signature compatibility with the other builders.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001606
1607 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a type-conversion operator turning ``val`` into ``out_dtype``.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the input/output name lists may be corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [result_tens.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=result_tens.shape,
        input_dtype=val.dtype,
        output_dtype=result_tens.dtype,
        result_tensor=result_tens,
        input_list=in_names,
        output_list=out_names,
        num_operands=p_count + c_count,
    )
    if not checks_ok:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
1640
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Random input/output zero points are chosen (INT8/UINT8 ranges, or 0 /
    32768 for UINT16). For the zero-point ERROR_IF tests a non-zero zero
    point is forced. A random scale is drawn, one per channel of the last
    dimension when ``per_channel`` is set and a single one otherwise.
    Each scale is converted to a multiplier/shift pair with
    TosaQuantGen.computeMultiplierAndShift and stored in the RESCALE
    attribute.

    When ``scale32`` is set and no error is under test, the placeholder
    input data file is clamped so that every (value - input_zp) lies in
    [-(1 << (shift - 1)), (1 << (shift - 1)) - 1] for its channel's shift.
    The file is rewritten only if any value changed.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^output_width) / (2^input_width), with a in [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Do the bound arithmetic on Python ints: shift_arr holds np.int32,
        # and shifting a fixed-width int32 by more than 31 bits overflows
        # (e.g. under NumPy 2 scalar promotion rules), corrupting the bounds.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -(1 << (shift - 1))
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1)))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensor=result_tens,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1795
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor holding ``cond`` for a COND_IF.

    Normally this is a rank-0 (size 1) BOOL constant. For the
    CondIfCondNotMatchingBool ERROR_IF test a wrong dtype is chosen via
    get_wrong_output_type. For CondIfCondShapeNotSizeOne the shape becomes
    either [2] or [1, 2].
    """
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    cond_shape = []
    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        cond_shape = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1812
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output one INT32 constant.

    The supplied ``then_tens``/``else_tens`` are ignored except for the
    shape of ``then_tens``, which becomes the output shape. ``cond`` is the
    condition value. For the CondIfOutputList{Then,Else}GraphMismatch
    ERROR_IF tests, the matching block outputs a constant with a
    deliberately different shape.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    mismatch_then = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    mismatch_else = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if mismatch_then or mismatch_else:
        # Change every dimension; dimensions of 3 or less only grow so the
        # resulting shape stays positive.
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's result tensor, shaped like the (correct) block outputs.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant that is also the block's output.
    for block_name, good_arr, use_bad in (
        (then_block, then_arr, mismatch_then),
        (else_block, else_arr, mismatch_else),
    ):
        self.ser.addBasicBlock(block_name)
        if use_bad:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, good_arr)
        self.ser.addOutputTensor(block_out)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if checks_ok else None
1881
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose blocks combine ``a`` and ``b`` with a binary op.

    For FP32/BF16/FP16/INT32 the THEN block adds and the ELSE block
    subtracts. For INT8/INT16 the blocks apply LOGICAL_RIGHT_SHIFT and
    LOGICAL_LEFT_SHIFT respectively. For the
    CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests, the
    selected block gets an input or output with a deliberately different
    shape.

    Returns the result tensor, or None when ERROR_IF validation fails.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Change every dimension, but only grow dimensions of 3 or less
            # so the mismatched shape never becomes zero or negative.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be called ``op``: that would shadow the
    # operator dict that is passed to the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1964
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP that adds ``a`` into an accumulator ``iter_val`` times.

    The loop state is (counter, a, acc):
    - the COND block outputs ``counter > 0``;
    - the BODY block outputs (counter - 1, a, a + acc).

    The accumulator starts as zeros with ``a``'s shape and dtype. The
    operator's accumulator output is registered as a graph output and
    returned.

    For the WHILE_LOOP ERROR_IF tests, selected block inputs/outputs (or the
    operator's accumulator output) get a deliberately mismatched shape. The
    COND output can also be given a non-BOOL dtype or a shape that is not
    size one.

    Returns the accumulator output tensor, or None when ERROR_IF
    validation fails.
    """

    def mismatched_shape(shape):
        # Change every dimension by a non-zero amount; dimensions of 3 or
        # less only grow so the resulting shape stays positive.
        return [
            dim + self.rng.choice([-3, -2, 2, 3])
            if dim > 3
            else dim + self.rng.choice([1, 2, 4])
            for dim in shape
        ]

    # Loop counter (named iter_tens to avoid shadowing the builtin ``iter``)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out = self.ser.addIntermediate(mismatched_shape(acc.shape), acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = deepcopy(iter_tens)
        incorrect_iter.shape = mismatched_shape(incorrect_iter.shape)
        if len(incorrect_iter.shape) == 0:
            # A rank-0 counter has no dimensions to change, so add one.
            incorrect_iter.shape.append(self.rng.choice([2, 3]))

        incorrect_acc = deepcopy(acc)
        incorrect_acc.shape = mismatched_shape(incorrect_acc.shape)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2085
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) pair and "types" list
            bound what may be generated.
        shapeFilter: requested shapes; a falsy value means no restriction.
        rankFilter: requested ranks, or None to use a default range.
        dtypeFilter: requested dtypes, or None for all of the op's types.
        testType: "positive" or "negative".
        validator: ERROR_IF validator for negative tests; the "param_reqs"
            it reports override the rank/dtype/shape filters.

    Returns:
        Dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries, or
        None for a negative test without a validator or an unknown testType.
    """
    # Default testing rank range, 1-4 inclusive, keeps test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Ensure rankFilter values are allowed by operator
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Ensure default behaviour is bounded by default range or by operator,
        # whichever is the smaller range of ranks.
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Clip the default range to the operator's ranks so an op whose
            # minimum rank is above the default start never gets tests with
            # a rank it does not support.
            cleanRankFilter = range(
                max(rmin, default_test_rank_range.start),
                min(rmax + 1, default_test_rank_range.stop),
            )
            if not cleanRankFilter:
                # No overlap with the default range: use the smallest ranks
                # the operator supports, limited to the default range length.
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Keep the operator dtypes that were requested; list entries (such as
        # the [input, weight, accumulator] conv combinations) match on their
        # first element.
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None
        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF validator
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    # Unknown test type: nothing to generate
    return None
2166
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptors for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to;
            None (the default) means no shape restriction.
        rankFilter: optional iterable of ranks to restrict generation to.
        dtypeFilter: optional iterable of dtypes to restrict generation to.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples, one per generated test.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    if shapeFilter is None:
        # Avoid a shared mutable default; [None] means "no shape restriction"
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator from the configured seed
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (string representation,
                    # build function argument list) tuples
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                        else:
                            testStr = "{}_ERRORIF_{}_{}_{}".format(
                                opName, error_name, shapeStr, typeStr
                            )
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement. A test is dropped if ANY validator flags it.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2273
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the graph for one test and serialize it when it is valid.

    A fresh serializer is created for the test, tensor values come from the
    op's TensorValuesGen function, and the op's build function constructs the
    graph. A truthy build result is serialized as "test"; otherwise a message
    reporting an invalid ERROR_IF test is printed.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Every test is written through its own serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholderCount, constCount = op["operands"]
    num_operands = placeholderCount + constCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT takes one input per shape; every other op one per operand
        isConcat = op["op"] == Op.CONCAT
        dtypeList = [dtype_or_dtypeList] * (
            len(shapeList) if isConcat else num_operands
        )

    if not op["op"] == Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        callArgs = [self, op, *tens, *testArgs]
        callKwargs = {}
        if error_if_validators is None:
            # Without ERROR_IF validators qinfo trails the test arguments
            if qinfo is not None:
                callArgs.append(qinfo)
        else:
            callKwargs["validator_fcns"] = error_if_validators
            callKwargs["error_name"] = error_name
            if qinfo is not None:
                callKwargs["qinfo"] = qinfo
        resultName = build_fcn(*callArgs, **callKwargs)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002364
def createDynamicOpLists(self):
    """Expand the *_TEMPLATE convolution entries into per-kernel-size ops.

    For each kernel size, each template is shallow-copied under a name such as
    "conv2d_3x1". The copy's "filter" is set to the kernel and its "template"
    flag is cleared. Afterwards every entry still flagged as a template is
    removed. Once conv2d_TEMPLATE is gone, further calls return immediately.
    """
    opList = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in opList:
        # Already created these lists (can occur when class is initialized more than once)
        return

    def addKernelVariant(prefix, kernel):
        # e.g. prefix "conv2d" with kernel [3, 1] becomes "conv2d_3x1"
        variantName = "{}_{}".format(
            prefix, "x".join("{}".format(dim) for dim in kernel)
        )
        variant = opList["{}_TEMPLATE".format(prefix)].copy()
        variant["filter"] = kernel
        variant["template"] = False
        opList[variantName] = variant

    # Dynamically create op lists for convolutions with a list of kernel sizes
    kernels2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    for kernel in kernels2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            addKernelVariant(prefix, kernel)

    kernels3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    for kernel in kernels3d:
        addKernelVariant("conv3d", kernel)

    # Collect first, delete second: the dict must not change size while it
    # is being iterated.
    templateNames = [
        name for name, entry in opList.items() if entry.get("template")
    ]
    for name in templateNames:
        del opList[name]
2415
def initOpListDefaults(self):
    """Lint TOSA_OP_LIST entries and fill in optional fields left unset.

    Every entry must provide an "operands" pair, a four-element "build_fcn"
    tuple, a "types" field and an "op" field. An entry without "rank" gets
    DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op found with a missing/invalid field.
    """
    for opName, opInfo in self.TOSA_OP_LIST.items():
        # Required: (placeholder count, const count)
        try:
            _placeholders, _consts = opInfo["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {opName} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        # Required: (build, TensorGen, TensorValuesGen, ArgGen)
        try:
            _build, _tgen, _tvgen, _agen = opInfo["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {opName} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        try:
            _ = opInfo["types"]
        except KeyError:
            raise Exception(
                f"Op {opName} is missing a valid type list in TOSA_OP_LIST"
            )

        try:
            _ = opInfo["op"]
        except KeyError:
            raise Exception(f"Op {opName} is missing the Op field in TOSA_OP_LIST")

        # Optional: fall back to the class-wide default rank range
        if "rank" not in opInfo:
            opInfo["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002457
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the tosa Op enum value for the operator (e.g. Op.ARGMAX)
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE,
#     i.e. (1, TOSA_TENSOR_MAX_RANK)
# 'build_fcn': tuple of (build_operator(), TensorGen function,
#     TensorValuesGen function, ArgGen function or None)
# 'types': array of datatypes to be tested (an entry may itself be a list of
#     types, e.g. the TYPE_CONV combinations)
# 'qgen': optional TosaQuantGen function, called as
#     qgen(gen, op, dtype, error_name) to produce the qinfo for the build
# 'error_if_validators': optional tuple of ERROR_IF validators used to
#     generate negative tests
# 'invalid_test_validators': optional tuple of validators; positive tests
#     any of them flag are dropped
# 'template': optional, True for entries that createDynamicOpLists expands
#     into per-kernel ops and then removes
# 'filter': kernel size; set by createDynamicOpLists on the expanded ops
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, INT8/INT16/INT32 and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16 only

# INT8/INT16 plus floating-types (no INT32)
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range used for ops whose TOSA_OP_LIST entry has no 'rank' field
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002509
2510 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002511 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002512 "argmax": {
2513 "op": Op.ARGMAX,
2514 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002515 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002516 "build_fcn": (
2517 build_argmax,
2518 TosaTensorGen.tgBasic,
2519 TosaTensorValuesGen.tvgDefault,
2520 TosaArgGen.agAxis,
2521 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002522 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002523 "error_if_validators": (
2524 TosaErrorValidator.evAxisSmallerZero,
2525 TosaErrorValidator.evAxisLargerRank,
2526 TosaErrorValidator.evArgmaxOutputRankMismatch,
2527 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2528 TosaErrorValidator.evWrongRank,
2529 TosaErrorValidator.evWrongInputType,
2530 TosaErrorValidator.evWrongOutputType,
2531 TosaErrorValidator.evWrongInputList,
2532 TosaErrorValidator.evWrongOutputList,
2533 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002534 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002535 "avg_pool2d": {
2536 "op": Op.AVG_POOL2D,
2537 "operands": (1, 0),
2538 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002539 "build_fcn": (
2540 build_pool2d,
2541 TosaTensorGen.tgNHWC,
2542 TosaTensorValuesGen.tvgDefault,
2543 TosaArgGen.agPooling,
2544 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002545 "qgen": TosaQuantGen.qgUnary,
2546 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002547 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002548 "error_if_validators": (
2549 TosaErrorValidator.evKernelSmallerOne,
2550 TosaErrorValidator.evStrideSmallerOne,
2551 TosaErrorValidator.evPadSmallerZero,
2552 TosaErrorValidator.evWrongRank,
2553 TosaErrorValidator.evWrongInputType,
2554 TosaErrorValidator.evWrongOutputType,
2555 TosaErrorValidator.evWrongInputList,
2556 TosaErrorValidator.evWrongOutputList,
2557 TosaErrorValidator.evInputZeroPointNotZero,
2558 TosaErrorValidator.evOutputZeroPointNotZero,
2559 TosaErrorValidator.evPadLargerEqualKernel,
2560 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002561 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002562 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002563 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002564 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002565 "conv2d_TEMPLATE": {
2566 "op": Op.CONV2D,
2567 "operands": (1, 2),
2568 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002569 "build_fcn": (
2570 build_conv2d,
2571 TosaTensorGen.tgConv2D,
2572 TosaTensorValuesGen.tvgDefault,
2573 TosaArgGen.agConv,
2574 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002575 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002576 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002577 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2578 "error_if_validators": (
2579 TosaErrorValidator.evWrongInputType,
2580 TosaErrorValidator.evWrongOutputType,
2581 TosaErrorValidator.evWrongInputList,
2582 TosaErrorValidator.evWrongOutputList,
2583 TosaErrorValidator.evInputZeroPointNotZero,
2584 TosaErrorValidator.evWeightZeroPointNotZero,
2585 TosaErrorValidator.evPadSmallerZero,
2586 TosaErrorValidator.evStrideSmallerOne,
2587 TosaErrorValidator.evDilationSmallerOne,
2588 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002589 TosaErrorValidator.evConvOutputShapeMismatch,
2590 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002591 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002592 "template": True,
2593 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002594 # Templated operator. Filled in by createDynamicOpLists
2595 "conv3d_TEMPLATE": {
2596 "op": Op.CONV3D,
2597 "operands": (1, 2),
2598 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002599 "build_fcn": (
2600 build_conv3d,
2601 TosaTensorGen.tgConv3D,
2602 TosaTensorValuesGen.tvgDefault,
2603 TosaArgGen.agConv,
2604 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002605 "qgen": TosaQuantGen.qgConv,
2606 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002607 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2608 "error_if_validators": (
2609 TosaErrorValidator.evWrongInputType,
2610 TosaErrorValidator.evWrongOutputType,
2611 TosaErrorValidator.evWrongInputList,
2612 TosaErrorValidator.evWrongOutputList,
2613 TosaErrorValidator.evInputZeroPointNotZero,
2614 TosaErrorValidator.evWeightZeroPointNotZero,
2615 TosaErrorValidator.evPadSmallerZero,
2616 TosaErrorValidator.evStrideSmallerOne,
2617 TosaErrorValidator.evDilationSmallerOne,
2618 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002619 TosaErrorValidator.evConvOutputShapeMismatch,
2620 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002621 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002622 "template": True,
2623 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002624 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002625 "depthwise_conv2d_TEMPLATE": {
2626 "op": Op.DEPTHWISE_CONV2D,
2627 "operands": (1, 2),
2628 "filter": [1, 1],
2629 "rank": (4, 4),
2630 "build_fcn": (
2631 build_depthwise_conv2d,
2632 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002633 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002634 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002635 ),
2636 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002637 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002638 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2639 "error_if_validators": (
2640 TosaErrorValidator.evWrongInputType,
2641 TosaErrorValidator.evWrongOutputType,
2642 TosaErrorValidator.evWrongInputList,
2643 TosaErrorValidator.evWrongOutputList,
2644 TosaErrorValidator.evInputZeroPointNotZero,
2645 TosaErrorValidator.evWeightZeroPointNotZero,
2646 TosaErrorValidator.evPadSmallerZero,
2647 TosaErrorValidator.evStrideSmallerOne,
2648 TosaErrorValidator.evDilationSmallerOne,
2649 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002650 TosaErrorValidator.evConvOutputShapeMismatch,
2651 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002652 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002653 "template": True,
2654 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002655 "fully_connected": {
2656 "op": Op.FULLY_CONNECTED,
2657 "operands": (1, 2),
2658 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002659 "build_fcn": (
2660 build_fully_connected,
2661 TosaTensorGen.tgFullyConnected,
2662 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002663 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002664 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002665 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002666 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002667 "error_if_validators": (
2668 TosaErrorValidator.evInputZeroPointNotZero,
2669 TosaErrorValidator.evWeightZeroPointNotZero,
2670 TosaErrorValidator.evWrongRank,
2671 TosaErrorValidator.evWrongInputType,
2672 TosaErrorValidator.evWrongOutputType,
2673 TosaErrorValidator.evWrongInputList,
2674 TosaErrorValidator.evWrongOutputList,
2675 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002676 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002677 "matmul": {
2678 "op": Op.MATMUL,
2679 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002680 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002681 "build_fcn": (
2682 build_matmul,
2683 TosaTensorGen.tgMatmul,
2684 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002685 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002686 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002687 "qgen": TosaQuantGen.qgMatmul,
2688 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002689 "error_if_validators": (
2690 TosaErrorValidator.evInputZeroPointNotZero,
2691 TosaErrorValidator.evWrongRank,
2692 TosaErrorValidator.evWrongInputType,
2693 TosaErrorValidator.evWrongOutputType,
2694 TosaErrorValidator.evWrongInputList,
2695 TosaErrorValidator.evWrongOutputList,
2696 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002698 "max_pool2d": {
2699 "op": Op.MAX_POOL2D,
2700 "operands": (1, 0),
2701 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002702 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002703 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002704 TosaTensorGen.tgNHWC,
2705 TosaTensorValuesGen.tvgDefault,
2706 TosaArgGen.agPooling,
2707 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002708 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002709 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002710 "error_if_validators": (
2711 TosaErrorValidator.evKernelSmallerOne,
2712 TosaErrorValidator.evStrideSmallerOne,
2713 TosaErrorValidator.evPadSmallerZero,
2714 TosaErrorValidator.evWrongRank,
2715 TosaErrorValidator.evWrongInputType,
2716 TosaErrorValidator.evWrongOutputType,
2717 TosaErrorValidator.evWrongInputList,
2718 TosaErrorValidator.evWrongOutputList,
2719 TosaErrorValidator.evPadLargerEqualKernel,
2720 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002721 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002722 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002723 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002724 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002725 "transpose_conv2d_TEMPLATE": {
2726 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002727 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002728 "rank": (4, 4),
2729 "build_fcn": (
2730 build_transpose_conv2d,
2731 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002732 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002733 TosaArgGen.agTransposeConv2D,
2734 ),
2735 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002736 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002737 "invalid_test_validators": (
2738 TosaInvalidValidator.ivHeightWidthInvalid,
2739 TosaInvalidValidator.ivNonPositiveOutputShape,
2740 ),
2741 "error_if_validators": (
2742 TosaErrorValidator.evWrongInputType,
2743 TosaErrorValidator.evWrongOutputType,
2744 TosaErrorValidator.evWrongInputList,
2745 TosaErrorValidator.evWrongOutputList,
2746 TosaErrorValidator.evInputZeroPointNotZero,
2747 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002748 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002749 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002750 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002751 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002752 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002753 "template": True,
2754 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002755 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002756 "clamp": {
2757 "op": Op.CLAMP,
2758 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002759 "build_fcn": (
2760 build_clamp,
2761 TosaTensorGen.tgBasic,
2762 TosaTensorValuesGen.tvgDefault,
2763 None,
2764 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002765 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002766 "error_if_validators": (
2767 TosaErrorValidator.evMaxSmallerMin,
2768 TosaErrorValidator.evWrongInputType,
2769 TosaErrorValidator.evWrongOutputType,
2770 TosaErrorValidator.evWrongInputList,
2771 TosaErrorValidator.evWrongOutputList,
2772 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002773 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002774 "sigmoid": {
2775 "op": Op.SIGMOID,
2776 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002777 "build_fcn": (
2778 build_sigmoid,
2779 TosaTensorGen.tgBasic,
2780 TosaTensorValuesGen.tvgDefault,
2781 None,
2782 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002783 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002784 "error_if_validators": (
2785 TosaErrorValidator.evWrongInputType,
2786 TosaErrorValidator.evWrongOutputType,
2787 TosaErrorValidator.evWrongInputList,
2788 TosaErrorValidator.evWrongOutputList,
2789 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002790 },
2791 "tanh": {
2792 "op": Op.TANH,
2793 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002794 "build_fcn": (
2795 build_tanh,
2796 TosaTensorGen.tgBasic,
2797 TosaTensorValuesGen.tvgDefault,
2798 None,
2799 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002800 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002801 "error_if_validators": (
2802 TosaErrorValidator.evWrongInputType,
2803 TosaErrorValidator.evWrongOutputType,
2804 TosaErrorValidator.evWrongInputList,
2805 TosaErrorValidator.evWrongOutputList,
2806 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002807 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002808 # Elementwise Binary Operators
2809 "add": {
2810 "op": Op.ADD,
2811 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002812 "build_fcn": (
2813 build_binary_broadcast,
2814 TosaTensorGen.tgBroadcastFuzz,
2815 TosaTensorValuesGen.tvgAddSub,
2816 None,
2817 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002818 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002819 "error_if_validators": (
2820 TosaErrorValidator.evRankMismatch,
2821 TosaErrorValidator.evWrongInputType,
2822 TosaErrorValidator.evWrongOutputType,
2823 TosaErrorValidator.evWrongInputList,
2824 TosaErrorValidator.evWrongOutputList,
2825 TosaErrorValidator.evDimensionMismatch,
2826 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002827 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002828 "arithmetic_right_shift": {
2829 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2830 "operands": (2, 0),
2831 "build_fcn": (
2832 build_arithmetic_right_shift,
2833 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002834 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002835 TosaArgGen.agArithmeticRightShift,
2836 ),
2837 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002838 "error_if_validators": (
2839 TosaErrorValidator.evRankMismatch,
2840 TosaErrorValidator.evWrongInputType,
2841 TosaErrorValidator.evWrongOutputType,
2842 TosaErrorValidator.evWrongInputList,
2843 TosaErrorValidator.evWrongOutputList,
2844 TosaErrorValidator.evDimensionMismatch,
2845 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002846 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002847 "bitwise_and": {
2848 "op": Op.BITWISE_AND,
2849 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002850 "build_fcn": (
2851 build_binary_broadcast,
2852 TosaTensorGen.tgBroadcastFuzz,
2853 TosaTensorValuesGen.tvgDefault,
2854 None,
2855 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002856 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002857 "error_if_validators": (
2858 TosaErrorValidator.evRankMismatch,
2859 TosaErrorValidator.evWrongInputType,
2860 TosaErrorValidator.evWrongOutputType,
2861 TosaErrorValidator.evWrongInputList,
2862 TosaErrorValidator.evWrongOutputList,
2863 TosaErrorValidator.evDimensionMismatch,
2864 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002865 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002866 "bitwise_or": {
2867 "op": Op.BITWISE_OR,
2868 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002869 "build_fcn": (
2870 build_binary_broadcast,
2871 TosaTensorGen.tgBroadcastFuzz,
2872 TosaTensorValuesGen.tvgDefault,
2873 None,
2874 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002875 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002876 "error_if_validators": (
2877 TosaErrorValidator.evRankMismatch,
2878 TosaErrorValidator.evWrongInputType,
2879 TosaErrorValidator.evWrongOutputType,
2880 TosaErrorValidator.evWrongInputList,
2881 TosaErrorValidator.evWrongOutputList,
2882 TosaErrorValidator.evDimensionMismatch,
2883 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002884 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002885 "bitwise_xor": {
2886 "op": Op.BITWISE_XOR,
2887 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002888 "build_fcn": (
2889 build_binary_broadcast,
2890 TosaTensorGen.tgBroadcastFuzz,
2891 TosaTensorValuesGen.tvgDefault,
2892 None,
2893 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002894 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002895 "error_if_validators": (
2896 TosaErrorValidator.evRankMismatch,
2897 TosaErrorValidator.evWrongInputType,
2898 TosaErrorValidator.evWrongOutputType,
2899 TosaErrorValidator.evWrongInputList,
2900 TosaErrorValidator.evWrongOutputList,
2901 TosaErrorValidator.evDimensionMismatch,
2902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002903 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002904 "intdiv": {
2905 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002906 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002907 "build_fcn": (
2908 build_binary_broadcast,
2909 TosaTensorGen.tgBroadcastFuzz,
2910 TosaTensorValuesGen.tvgIntDiv,
2911 None,
2912 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002913 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002914 "error_if_validators": (
2915 TosaErrorValidator.evRankMismatch,
2916 TosaErrorValidator.evWrongInputType,
2917 TosaErrorValidator.evWrongOutputType,
2918 TosaErrorValidator.evWrongInputList,
2919 TosaErrorValidator.evWrongOutputList,
2920 TosaErrorValidator.evDimensionMismatch,
2921 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002922 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002923 "logical_and": {
2924 "op": Op.LOGICAL_AND,
2925 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002926 "build_fcn": (
2927 build_binary_broadcast,
2928 TosaTensorGen.tgBroadcastFuzz,
2929 TosaTensorValuesGen.tvgDefault,
2930 None,
2931 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002932 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002933 "error_if_validators": (
2934 TosaErrorValidator.evRankMismatch,
2935 TosaErrorValidator.evWrongInputType,
2936 TosaErrorValidator.evWrongOutputType,
2937 TosaErrorValidator.evWrongInputList,
2938 TosaErrorValidator.evWrongOutputList,
2939 TosaErrorValidator.evDimensionMismatch,
2940 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002941 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002942 "logical_left_shift": {
2943 "op": Op.LOGICAL_LEFT_SHIFT,
2944 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002945 "build_fcn": (
2946 build_binary_broadcast,
2947 TosaTensorGen.tgBroadcastFuzz,
2948 TosaTensorValuesGen.tvgLogicalShift,
2949 None,
2950 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002951 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002952 "error_if_validators": (
2953 TosaErrorValidator.evRankMismatch,
2954 TosaErrorValidator.evWrongInputType,
2955 TosaErrorValidator.evWrongOutputType,
2956 TosaErrorValidator.evWrongInputList,
2957 TosaErrorValidator.evWrongOutputList,
2958 TosaErrorValidator.evDimensionMismatch,
2959 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002960 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002961 "logical_right_shift": {
2962 "op": Op.LOGICAL_RIGHT_SHIFT,
2963 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002964 "build_fcn": (
2965 build_binary_broadcast,
2966 TosaTensorGen.tgBroadcastFuzz,
2967 TosaTensorValuesGen.tvgLogicalShift,
2968 None,
2969 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002970 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002971 "error_if_validators": (
2972 TosaErrorValidator.evRankMismatch,
2973 TosaErrorValidator.evWrongInputType,
2974 TosaErrorValidator.evWrongOutputType,
2975 TosaErrorValidator.evWrongInputList,
2976 TosaErrorValidator.evWrongOutputList,
2977 TosaErrorValidator.evDimensionMismatch,
2978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002979 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002980 "logical_or": {
2981 "op": Op.LOGICAL_OR,
2982 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002983 "build_fcn": (
2984 build_binary_broadcast,
2985 TosaTensorGen.tgBroadcastFuzz,
2986 TosaTensorValuesGen.tvgDefault,
2987 None,
2988 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002989 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002990 "error_if_validators": (
2991 TosaErrorValidator.evRankMismatch,
2992 TosaErrorValidator.evWrongInputType,
2993 TosaErrorValidator.evWrongOutputType,
2994 TosaErrorValidator.evWrongInputList,
2995 TosaErrorValidator.evWrongOutputList,
2996 TosaErrorValidator.evDimensionMismatch,
2997 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002998 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002999 "logical_xor": {
3000 "op": Op.LOGICAL_XOR,
3001 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003002 "build_fcn": (
3003 build_binary_broadcast,
3004 TosaTensorGen.tgBroadcastFuzz,
3005 TosaTensorValuesGen.tvgDefault,
3006 None,
3007 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003008 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003009 "error_if_validators": (
3010 TosaErrorValidator.evRankMismatch,
3011 TosaErrorValidator.evWrongInputType,
3012 TosaErrorValidator.evWrongOutputType,
3013 TosaErrorValidator.evWrongInputList,
3014 TosaErrorValidator.evWrongOutputList,
3015 TosaErrorValidator.evDimensionMismatch,
3016 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003017 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003018 "maximum": {
3019 "op": Op.MAXIMUM,
3020 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003021 "build_fcn": (
3022 build_binary_broadcast,
3023 TosaTensorGen.tgBroadcastFuzz,
3024 TosaTensorValuesGen.tvgDefault,
3025 None,
3026 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003027 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003028 "error_if_validators": (
3029 TosaErrorValidator.evRankMismatch,
3030 TosaErrorValidator.evWrongInputType,
3031 TosaErrorValidator.evWrongOutputType,
3032 TosaErrorValidator.evWrongInputList,
3033 TosaErrorValidator.evWrongOutputList,
3034 TosaErrorValidator.evDimensionMismatch,
3035 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003036 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003037 "minimum": {
3038 "op": Op.MINIMUM,
3039 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003040 "build_fcn": (
3041 build_binary_broadcast,
3042 TosaTensorGen.tgBroadcastFuzz,
3043 TosaTensorValuesGen.tvgDefault,
3044 None,
3045 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003046 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003047 "error_if_validators": (
3048 TosaErrorValidator.evRankMismatch,
3049 TosaErrorValidator.evWrongInputType,
3050 TosaErrorValidator.evWrongOutputType,
3051 TosaErrorValidator.evWrongInputList,
3052 TosaErrorValidator.evWrongOutputList,
3053 TosaErrorValidator.evDimensionMismatch,
3054 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003056 "mul": {
3057 "op": Op.MUL,
3058 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003059 "build_fcn": (
3060 build_mul,
3061 TosaTensorGen.tgBroadcastFuzz,
3062 TosaTensorValuesGen.tvgMul,
3063 TosaArgGen.agMul,
3064 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003065 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003066 "error_if_validators": (
3067 TosaErrorValidator.evWrongInputType,
3068 TosaErrorValidator.evWrongOutputType,
3069 TosaErrorValidator.evWrongInputList,
3070 TosaErrorValidator.evWrongOutputList,
3071 TosaErrorValidator.evRankMismatch,
3072 TosaErrorValidator.evDimensionMismatch,
3073 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003074 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003075 "pow": {
3076 "op": Op.POW,
3077 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003078 "build_fcn": (
3079 build_binary_broadcast,
3080 TosaTensorGen.tgBroadcastFuzz,
3081 TosaTensorValuesGen.tvgDefault,
3082 None,
3083 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003084 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003085 "error_if_validators": (
3086 TosaErrorValidator.evRankMismatch,
3087 TosaErrorValidator.evWrongInputType,
3088 TosaErrorValidator.evWrongOutputType,
3089 TosaErrorValidator.evWrongInputList,
3090 TosaErrorValidator.evWrongOutputList,
3091 TosaErrorValidator.evDimensionMismatch,
3092 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003094 "sub": {
3095 "op": Op.SUB,
3096 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003097 "build_fcn": (
3098 build_binary_broadcast,
3099 TosaTensorGen.tgBroadcastFuzz,
3100 TosaTensorValuesGen.tvgAddSub,
3101 None,
3102 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003103 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003104 "error_if_validators": (
3105 TosaErrorValidator.evRankMismatch,
3106 TosaErrorValidator.evWrongInputType,
3107 TosaErrorValidator.evWrongOutputType,
3108 TosaErrorValidator.evWrongInputList,
3109 TosaErrorValidator.evWrongOutputList,
3110 TosaErrorValidator.evDimensionMismatch,
3111 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003112 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003113 "table": {
3114 "op": Op.TABLE,
3115 # Use the automatic generation functions to create the input array
3116 # but create the table tensor in the build function, as it may be
3117 # a different type from the input
3118 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003119 "build_fcn": (
3120 build_table,
3121 TosaTensorGen.tgBasic,
3122 TosaTensorValuesGen.tvgDefault,
3123 TosaArgGen.agTable,
3124 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003125 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003126 "error_if_validators": (
3127 TosaErrorValidator.evWrongInputType,
3128 TosaErrorValidator.evWrongOutputType,
3129 TosaErrorValidator.evWrongInputList,
3130 TosaErrorValidator.evWrongOutputList,
3131 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003132 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003133 # Elementwise Unary operators
3134 "abs": {
3135 "op": Op.ABS,
3136 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003137 "build_fcn": (
3138 build_unary,
3139 TosaTensorGen.tgBasic,
3140 TosaTensorValuesGen.tvgDefault,
3141 None,
3142 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003143 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003144 "error_if_validators": (
3145 TosaErrorValidator.evWrongInputType,
3146 TosaErrorValidator.evWrongOutputType,
3147 TosaErrorValidator.evWrongInputList,
3148 TosaErrorValidator.evWrongOutputList,
3149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "bitwise_not": {
3152 "op": Op.BITWISE_NOT,
3153 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003154 "build_fcn": (
3155 build_unary,
3156 TosaTensorGen.tgBasic,
3157 TosaTensorValuesGen.tvgDefault,
3158 None,
3159 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003160 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003161 "error_if_validators": (
3162 TosaErrorValidator.evWrongInputType,
3163 TosaErrorValidator.evWrongOutputType,
3164 TosaErrorValidator.evWrongInputList,
3165 TosaErrorValidator.evWrongOutputList,
3166 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003167 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 "ceil": {
3169 "op": Op.CEIL,
3170 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003171 "build_fcn": (
3172 build_unary,
3173 TosaTensorGen.tgBasic,
3174 TosaTensorValuesGen.tvgDefault,
3175 None,
3176 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003178 "error_if_validators": (
3179 TosaErrorValidator.evWrongInputType,
3180 TosaErrorValidator.evWrongOutputType,
3181 TosaErrorValidator.evWrongInputList,
3182 TosaErrorValidator.evWrongOutputList,
3183 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003184 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003185 "clz": {
3186 "op": Op.CLZ,
3187 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003188 "build_fcn": (
3189 build_unary,
3190 TosaTensorGen.tgBasic,
3191 TosaTensorValuesGen.tvgDefault,
3192 None,
3193 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003194 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003195 "error_if_validators": (
3196 TosaErrorValidator.evWrongInputType,
3197 TosaErrorValidator.evWrongOutputType,
3198 TosaErrorValidator.evWrongInputList,
3199 TosaErrorValidator.evWrongOutputList,
3200 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003201 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003202 "exp": {
3203 "op": Op.EXP,
3204 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003205 "build_fcn": (
3206 build_unary,
3207 TosaTensorGen.tgBasic,
3208 TosaTensorValuesGen.tvgDefault,
3209 None,
3210 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003212 "error_if_validators": (
3213 TosaErrorValidator.evWrongInputType,
3214 TosaErrorValidator.evWrongOutputType,
3215 TosaErrorValidator.evWrongInputList,
3216 TosaErrorValidator.evWrongOutputList,
3217 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003218 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003219 "floor": {
3220 "op": Op.FLOOR,
3221 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003222 "build_fcn": (
3223 build_unary,
3224 TosaTensorGen.tgBasic,
3225 TosaTensorValuesGen.tvgDefault,
3226 None,
3227 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003228 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003229 "error_if_validators": (
3230 TosaErrorValidator.evWrongInputType,
3231 TosaErrorValidator.evWrongOutputType,
3232 TosaErrorValidator.evWrongInputList,
3233 TosaErrorValidator.evWrongOutputList,
3234 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003235 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003236 "log": {
3237 "op": Op.LOG,
3238 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003239 "build_fcn": (
3240 build_unary,
3241 TosaTensorGen.tgBasic,
3242 TosaTensorValuesGen.tvgDefault,
3243 None,
3244 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003245 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003246 "error_if_validators": (
3247 TosaErrorValidator.evWrongInputType,
3248 TosaErrorValidator.evWrongOutputType,
3249 TosaErrorValidator.evWrongInputList,
3250 TosaErrorValidator.evWrongOutputList,
3251 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "logical_not": {
3254 "op": Op.LOGICAL_NOT,
3255 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003256 "build_fcn": (
3257 build_unary,
3258 TosaTensorGen.tgBasic,
3259 TosaTensorValuesGen.tvgDefault,
3260 None,
3261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003263 "error_if_validators": (
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003269 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "negate": {
3271 "op": Op.NEGATE,
3272 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273 "build_fcn": (
3274 build_unary,
3275 TosaTensorGen.tgBasic,
3276 TosaTensorValuesGen.tvgNegate,
3277 None,
3278 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003279 "qgen": TosaQuantGen.qgUnary,
3280 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003281 "error_if_validators": (
3282 TosaErrorValidator.evInputZeroPointNotZero,
3283 TosaErrorValidator.evOutputZeroPointNotZero,
3284 TosaErrorValidator.evWrongInputType,
3285 TosaErrorValidator.evWrongOutputType,
3286 TosaErrorValidator.evWrongInputList,
3287 TosaErrorValidator.evWrongOutputList,
3288 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003289 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003290 "reciprocal": {
3291 "op": Op.RECIPROCAL,
3292 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003293 "build_fcn": (
3294 build_unary,
3295 TosaTensorGen.tgBasic,
3296 TosaTensorValuesGen.tvgDefault,
3297 None,
3298 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003300 "error_if_validators": (
3301 TosaErrorValidator.evWrongInputType,
3302 TosaErrorValidator.evWrongOutputType,
3303 TosaErrorValidator.evWrongInputList,
3304 TosaErrorValidator.evWrongOutputList,
3305 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003306 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003307 "rsqrt": {
3308 "op": Op.RSQRT,
3309 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003310 "build_fcn": (
3311 build_unary,
3312 TosaTensorGen.tgBasic,
3313 TosaTensorValuesGen.tvgDefault,
3314 None,
3315 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003317 "error_if_validators": (
3318 TosaErrorValidator.evWrongInputType,
3319 TosaErrorValidator.evWrongOutputType,
3320 TosaErrorValidator.evWrongInputList,
3321 TosaErrorValidator.evWrongOutputList,
3322 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003323 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003324 # Elementwise Ternary operators
3325 "select": {
3326 "op": Op.SELECT,
3327 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003328 "build_fcn": (
3329 build_select,
3330 TosaTensorGen.tgBroadcastFuzz,
3331 TosaTensorValuesGen.tvgSelect,
3332 None,
3333 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003334 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003335 "error_if_validators": (
3336 TosaErrorValidator.evRankMismatch,
3337 TosaErrorValidator.evWrongInputType,
3338 TosaErrorValidator.evWrongOutputType,
3339 TosaErrorValidator.evWrongInputList,
3340 TosaErrorValidator.evWrongOutputList,
3341 TosaErrorValidator.evDimensionMismatch,
3342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003344 # Comparison operators
3345 "equal": {
3346 "op": Op.EQUAL,
3347 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003348 "build_fcn": (
3349 build_comparison,
3350 TosaTensorGen.tgBroadcastFuzz,
3351 TosaTensorValuesGen.tvgEqual,
3352 None,
3353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003355 "error_if_validators": (
3356 TosaErrorValidator.evRankMismatch,
3357 TosaErrorValidator.evWrongInputType,
3358 TosaErrorValidator.evWrongOutputType,
3359 TosaErrorValidator.evWrongInputList,
3360 TosaErrorValidator.evWrongOutputList,
3361 TosaErrorValidator.evDimensionMismatch,
3362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "greater_equal": {
3365 "op": Op.GREATER_EQUAL,
3366 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003367 "build_fcn": (
3368 build_comparison,
3369 TosaTensorGen.tgBroadcastFuzz,
3370 TosaTensorValuesGen.tvgDefault,
3371 None,
3372 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003373 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003374 "error_if_validators": (
3375 TosaErrorValidator.evRankMismatch,
3376 TosaErrorValidator.evWrongInputType,
3377 TosaErrorValidator.evWrongOutputType,
3378 TosaErrorValidator.evWrongInputList,
3379 TosaErrorValidator.evWrongOutputList,
3380 TosaErrorValidator.evDimensionMismatch,
3381 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003383 "greater": {
3384 "op": Op.GREATER,
3385 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003386 "build_fcn": (
3387 build_comparison,
3388 TosaTensorGen.tgBroadcastFuzz,
3389 TosaTensorValuesGen.tvgDefault,
3390 None,
3391 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003393 "error_if_validators": (
3394 TosaErrorValidator.evRankMismatch,
3395 TosaErrorValidator.evWrongInputType,
3396 TosaErrorValidator.evWrongOutputType,
3397 TosaErrorValidator.evWrongInputList,
3398 TosaErrorValidator.evWrongOutputList,
3399 TosaErrorValidator.evDimensionMismatch,
3400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003401 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003402 # Reduction operators
3403 "reduce_all": {
3404 "op": Op.REDUCE_ALL,
3405 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003406 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003407 "build_fcn": (
3408 build_reduce,
3409 TosaTensorGen.tgBasic,
3410 TosaTensorValuesGen.tvgDefault,
3411 TosaArgGen.agAxis,
3412 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003413 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003414 "error_if_validators": (
3415 TosaErrorValidator.evAxisLargerRank,
3416 TosaErrorValidator.evAxisSmallerZero,
3417 TosaErrorValidator.evShapeOfAxisNotOne,
3418 TosaErrorValidator.evWrongInputType,
3419 TosaErrorValidator.evWrongOutputType,
3420 TosaErrorValidator.evWrongRank,
3421 TosaErrorValidator.evWrongInputList,
3422 TosaErrorValidator.evWrongOutputList,
3423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 "reduce_any": {
3426 "op": Op.REDUCE_ANY,
3427 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003428 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003429 "build_fcn": (
3430 build_reduce,
3431 TosaTensorGen.tgBasic,
3432 TosaTensorValuesGen.tvgDefault,
3433 TosaArgGen.agAxis,
3434 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003436 "error_if_validators": (
3437 TosaErrorValidator.evAxisLargerRank,
3438 TosaErrorValidator.evAxisSmallerZero,
3439 TosaErrorValidator.evShapeOfAxisNotOne,
3440 TosaErrorValidator.evWrongInputType,
3441 TosaErrorValidator.evWrongOutputType,
3442 TosaErrorValidator.evWrongRank,
3443 TosaErrorValidator.evWrongInputList,
3444 TosaErrorValidator.evWrongOutputList,
3445 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003446 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003447 "reduce_max": {
3448 "op": Op.REDUCE_MAX,
3449 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003450 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003451 "build_fcn": (
3452 build_reduce,
3453 TosaTensorGen.tgBasic,
3454 TosaTensorValuesGen.tvgDefault,
3455 TosaArgGen.agAxis,
3456 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003457 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003458 "error_if_validators": (
3459 TosaErrorValidator.evAxisLargerRank,
3460 TosaErrorValidator.evAxisSmallerZero,
3461 TosaErrorValidator.evShapeOfAxisNotOne,
3462 TosaErrorValidator.evWrongInputType,
3463 TosaErrorValidator.evWrongOutputType,
3464 TosaErrorValidator.evWrongRank,
3465 TosaErrorValidator.evWrongInputList,
3466 TosaErrorValidator.evWrongOutputList,
3467 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003468 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003469 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003470 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003472 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003473 "build_fcn": (
3474 build_reduce,
3475 TosaTensorGen.tgBasic,
3476 TosaTensorValuesGen.tvgDefault,
3477 TosaArgGen.agAxis,
3478 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003479 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003480 "error_if_validators": (
3481 TosaErrorValidator.evAxisLargerRank,
3482 TosaErrorValidator.evAxisSmallerZero,
3483 TosaErrorValidator.evShapeOfAxisNotOne,
3484 TosaErrorValidator.evWrongInputType,
3485 TosaErrorValidator.evWrongOutputType,
3486 TosaErrorValidator.evWrongRank,
3487 TosaErrorValidator.evWrongInputList,
3488 TosaErrorValidator.evWrongOutputList,
3489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003491 "reduce_product": {
3492 "op": Op.REDUCE_PRODUCT,
3493 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003494 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003495 "build_fcn": (
3496 build_reduce,
3497 TosaTensorGen.tgBasic,
3498 TosaTensorValuesGen.tvgDefault,
3499 TosaArgGen.agAxis,
3500 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003501 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 "error_if_validators": (
3503 TosaErrorValidator.evAxisLargerRank,
3504 TosaErrorValidator.evAxisSmallerZero,
3505 TosaErrorValidator.evShapeOfAxisNotOne,
3506 TosaErrorValidator.evWrongInputType,
3507 TosaErrorValidator.evWrongOutputType,
3508 TosaErrorValidator.evWrongRank,
3509 TosaErrorValidator.evWrongInputList,
3510 TosaErrorValidator.evWrongOutputList,
3511 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003512 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "reduce_sum": {
3514 "op": Op.REDUCE_SUM,
3515 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003516 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003517 "build_fcn": (
3518 build_reduce,
3519 TosaTensorGen.tgBasic,
3520 TosaTensorValuesGen.tvgReduceSum,
3521 TosaArgGen.agAxis,
3522 ),
James Ward24dbc422022-10-19 12:20:31 +01003523 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003524 "error_if_validators": (
3525 TosaErrorValidator.evAxisLargerRank,
3526 TosaErrorValidator.evAxisSmallerZero,
3527 TosaErrorValidator.evShapeOfAxisNotOne,
3528 TosaErrorValidator.evWrongInputType,
3529 TosaErrorValidator.evWrongOutputType,
3530 TosaErrorValidator.evWrongRank,
3531 TosaErrorValidator.evWrongInputList,
3532 TosaErrorValidator.evWrongOutputList,
3533 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003534 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003535 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003536 "concat": {
3537 "op": Op.CONCAT,
3538 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003539 "build_fcn": (
3540 build_concat,
3541 TosaTensorGen.tgConcat,
3542 TosaTensorValuesGen.tvgConcat,
3543 TosaArgGen.agAxis,
3544 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003545 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003546 "error_if_validators": (
3547 TosaErrorValidator.evAxisLargerRank,
3548 TosaErrorValidator.evAxisSmallerZero,
3549 TosaErrorValidator.evConcatInputRankMismatch,
3550 TosaErrorValidator.evConcatShapeSumMismatch,
3551 TosaErrorValidator.evConcatInputDimMismatch,
3552 TosaErrorValidator.evWrongInputType,
3553 TosaErrorValidator.evWrongOutputType,
3554 TosaErrorValidator.evWrongOutputList,
3555 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003556 },
3557 "pad": {
3558 "op": Op.PAD,
3559 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003560 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003561 "build_fcn": (
3562 build_pad,
3563 TosaTensorGen.tgBasic,
3564 TosaTensorValuesGen.tvgDefault,
3565 TosaArgGen.agPad,
3566 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003567 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003568 "error_if_validators": (
3569 TosaErrorValidator.evWrongInputType,
3570 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003571 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003572 TosaErrorValidator.evWrongOutputType,
3573 TosaErrorValidator.evWrongInputList,
3574 TosaErrorValidator.evWrongOutputList,
3575 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003576 },
3577 "reshape": {
3578 "op": Op.RESHAPE,
3579 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003580 "build_fcn": (
3581 build_reshape,
3582 TosaTensorGen.tgBasic,
3583 TosaTensorValuesGen.tvgDefault,
3584 TosaArgGen.agReshape,
3585 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003586 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003587 "error_if_validators": (
3588 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3589 TosaErrorValidator.evWrongInputType,
3590 TosaErrorValidator.evWrongOutputType,
3591 TosaErrorValidator.evWrongInputList,
3592 TosaErrorValidator.evWrongOutputList,
3593 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003594 },
3595 "reverse": {
3596 "op": Op.REVERSE,
3597 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003598 "build_fcn": (
3599 build_reverse,
3600 TosaTensorGen.tgBasic,
3601 TosaTensorValuesGen.tvgDefault,
3602 TosaArgGen.agAxis,
3603 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003604 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003605 "error_if_validators": (
3606 TosaErrorValidator.evAxisSmallerZero,
3607 TosaErrorValidator.evAxisLargerRank,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongInputList,
3611 TosaErrorValidator.evWrongOutputList,
3612 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003613 },
3614 "slice": {
3615 "op": Op.SLICE,
3616 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003617 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618 "build_fcn": (
3619 build_slice,
3620 TosaTensorGen.tgBasic,
3621 TosaTensorValuesGen.tvgDefault,
3622 TosaArgGen.agSlice,
3623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003624 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003625 "error_if_validators": (
3626 TosaErrorValidator.evStartSmallerZero,
3627 TosaErrorValidator.evSizeSmallerEqualZero,
3628 TosaErrorValidator.evStartSizeOutsideBounds,
3629 TosaErrorValidator.evSizeOutputShapeMismatch,
3630 TosaErrorValidator.evInputSizeStartLengthMismatch,
3631 TosaErrorValidator.evWrongRank,
3632 TosaErrorValidator.evWrongInputType,
3633 TosaErrorValidator.evWrongOutputType,
3634 TosaErrorValidator.evWrongInputList,
3635 TosaErrorValidator.evWrongOutputList,
3636 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003637 },
3638 "tile": {
3639 "op": Op.TILE,
3640 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 "build_fcn": (
3642 build_tile,
3643 TosaTensorGen.tgBasic,
3644 TosaTensorValuesGen.tvgDefault,
3645 TosaArgGen.agTile,
3646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003648 "error_if_validators": (
3649 TosaErrorValidator.evWrongInputType,
3650 TosaErrorValidator.evWrongOutputType,
3651 TosaErrorValidator.evWrongInputList,
3652 TosaErrorValidator.evWrongOutputList,
3653 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003654 },
3655 "transpose": {
3656 "op": Op.TRANSPOSE,
3657 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003658 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003659 "build_fcn": (
3660 build_transpose,
3661 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003662 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003663 TosaArgGen.agTranspose,
3664 ),
3665 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003666 "error_if_validators": (
3667 TosaErrorValidator.evIndexOutsideBounds,
3668 TosaErrorValidator.evIndexUsedTwice,
3669 TosaErrorValidator.evWrongInputType,
3670 TosaErrorValidator.evWrongOutputType,
3671 TosaErrorValidator.evWrongInputList,
3672 TosaErrorValidator.evWrongOutputList,
3673 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003674 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003675 # Data nodes
3676 "const": {
3677 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003678 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003679 "build_fcn": (
3680 build_const,
3681 TosaTensorGen.tgBasic,
3682 TosaTensorValuesGen.tvgDefault,
3683 None,
3684 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003685 "types": TYPE_FIB,
3686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 "identity": {
3688 "op": Op.IDENTITY,
3689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 "build_fcn": (
3691 build_unary,
3692 TosaTensorGen.tgBasic,
3693 TosaTensorValuesGen.tvgDefault,
3694 None,
3695 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 "types": TYPE_FIB,
3697 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003698 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003699 "gather": {
3700 "op": Op.GATHER,
3701 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3702 "operands": (1, 0),
3703 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 "build_fcn": (
3705 build_gather,
3706 TosaTensorGen.tgBasic,
3707 TosaTensorValuesGen.tvgDefault,
3708 None,
3709 ),
James Ward24dbc422022-10-19 12:20:31 +01003710 "types": (
3711 DType.INT8,
3712 DType.INT16,
3713 DType.INT32,
3714 DType.FP16,
3715 DType.BF16,
3716 DType.FP32,
3717 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003718 "error_if_validators": (
3719 TosaErrorValidator.evWrongInputType,
3720 TosaErrorValidator.evWrongOutputType,
3721 TosaErrorValidator.evWrongInputList,
3722 TosaErrorValidator.evWrongOutputList,
3723 TosaErrorValidator.evWrongRank,
3724 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003725 },
3726 "scatter": {
3727 "op": Op.SCATTER,
3728 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003730 "operands": (2, 0),
3731 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 "build_fcn": (
3733 build_scatter,
3734 TosaTensorGen.tgScatter,
3735 TosaTensorValuesGen.tvgDefault,
3736 None,
3737 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003738 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 "error_if_validators": (
3740 TosaErrorValidator.evWrongInputType,
3741 TosaErrorValidator.evWrongOutputType,
3742 TosaErrorValidator.evWrongInputList,
3743 TosaErrorValidator.evWrongOutputList,
3744 TosaErrorValidator.evWrongRank,
3745 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003746 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003747 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003748 "resize": {
3749 "op": Op.RESIZE,
3750 "operands": (1, 0),
3751 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 "build_fcn": (
3753 build_resize,
3754 TosaTensorGen.tgNHWC,
3755 TosaTensorValuesGen.tvgDefault,
3756 TosaArgGen.agResize,
3757 ),
James Ward24dbc422022-10-19 12:20:31 +01003758 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 "invalid_test_validators": (
3760 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003761 ),
3762 "error_if_validators": (
3763 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003764 TosaErrorValidator.evScaleSmallerEqualZero,
3765 TosaErrorValidator.evScaleNLargerMax,
3766 TosaErrorValidator.evScaleDLargerMax,
3767 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003768 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003769 TosaErrorValidator.evBorderSmallerMin,
3770 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 TosaErrorValidator.evWrongInputType,
3772 TosaErrorValidator.evWrongOutputType,
3773 TosaErrorValidator.evWrongRank,
3774 TosaErrorValidator.evWrongInputList,
3775 TosaErrorValidator.evWrongOutputList,
3776 TosaErrorValidator.evBatchMismatch,
3777 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003778 TosaErrorValidator.evResizeOutputShapeMismatch,
3779 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003780 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003781 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003782 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003783 "cast": {
3784 "op": Op.CAST,
3785 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003786 "build_fcn": (
3787 build_cast,
3788 TosaTensorGen.tgBasic,
3789 TosaTensorValuesGen.tvgDefault,
3790 TosaArgGen.agCast,
3791 ),
James Ward8b390432022-08-12 20:48:56 +01003792 "types": (
3793 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003794 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003795 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003796 DType.INT8,
3797 DType.INT16,
3798 DType.INT32,
3799 DType.BOOL,
3800 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003801 "error_if_validators": (
3802 TosaErrorValidator.evWrongInputType,
3803 TosaErrorValidator.evWrongOutputType,
3804 TosaErrorValidator.evWrongInputList,
3805 TosaErrorValidator.evWrongOutputList,
3806 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003807 },
3808 "rescale": {
3809 "op": Op.RESCALE,
3810 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003811 "build_fcn": (
3812 build_rescale,
3813 TosaTensorGen.tgBasic,
3814 TosaTensorValuesGen.tvgDefault,
3815 TosaArgGen.agRescale,
3816 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003817 "types": [
3818 DType.UINT8,
3819 DType.INT8,
3820 DType.INT16,
3821 DType.INT32,
3822 DType.INT48,
3823 DType.UINT16,
3824 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003825 "error_if_validators": (
3826 TosaErrorValidator.evInputZeroPointNotZero,
3827 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003828 TosaErrorValidator.evU16InputZeroPointNotValid,
3829 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003830 TosaErrorValidator.evScaleTrue,
3831 TosaErrorValidator.evScaleNotTrue,
3832 TosaErrorValidator.evWrongInputType,
3833 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 TosaErrorValidator.evWrongInputList,
3835 TosaErrorValidator.evWrongOutputList,
3836 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003837 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003838 # Custom
3839 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003841 # Two variants of cond_if, one that generates one of two constant tensors (no
3842 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3843 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003844 "cond_if_const": {
3845 "op": Op.COND_IF,
3846 "operands": (0, 2),
3847 "build_fcn": (
3848 build_cond_if_const,
3849 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003850 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003851 TosaArgGen.agCondIf,
3852 ),
3853 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003854 "error_if_validators": (
3855 TosaErrorValidator.evOutputListThenGraphMismatch,
3856 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003857 TosaErrorValidator.evCondIfCondNotMatchingBool,
3858 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003859 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003860 },
3861 "cond_if_binary": {
3862 "op": Op.COND_IF,
3863 "operands": (2, 0),
3864 "build_fcn": (
3865 build_cond_if_binary,
3866 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003867 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003868 TosaArgGen.agCondIf,
3869 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003870 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 "error_if_validators": (
3872 TosaErrorValidator.evInputListThenGraphMismatch,
3873 TosaErrorValidator.evInputListElseGraphMismatch,
3874 TosaErrorValidator.evOutputListThenGraphMismatch,
3875 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003876 TosaErrorValidator.evCondIfCondNotMatchingBool,
3877 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003878 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003879 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003880 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003881 "while_loop": {
3882 "op": Op.WHILE_LOOP,
3883 "operands": (0, 1),
3884 "build_fcn": (
3885 build_while_loop,
3886 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003887 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003888 TosaArgGen.agWhileLoop,
3889 ),
3890 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 "error_if_validators": (
3892 TosaErrorValidator.evInputListOutputListMismatch,
3893 TosaErrorValidator.evInputListCondGraphMismatch,
3894 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3895 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3896 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003897 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003898 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003899 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003900 }
3901
Kevin Cheng550ccc52021-03-03 11:21:43 -08003902
Eric Kunzee5e26762020-10-13 16:11:07 -07003903class OutputShaper:
3904 # Methods in this class compute the expected output shape and datatype
3905 # for common classes of operations
def __init__(self):
    """Nothing to initialise: the shape helpers below are called as static methods."""
3908
3909 # These methods return arguments that can be used for
3910 # creating a new output tensor
3911 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for an elementwise binary op with broadcasting.

    Each output dimension is taken from ``a``. The exception is a size-1
    dimension of ``a``: when no ERROR_IF case is being generated
    (``error_name is None``), it takes ``b``'s size for that axis.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only drawn from for ``ErrorIf.WrongOutputType``.
        a, b: input tensors. Ranks must match unless ``error_name`` is
            ``ErrorIf.RankMismatch``; dtypes must always match.
        error_name: the ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    # b.shape is only indexed when broadcasting, so a shorter b (RankMismatch)
    # is never read out of range.
    out_shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    # Any of these other than the input dtype is a deliberately wrong output.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07003940
3941 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output tensor for a binary op whose inputs must match exactly.

    ``a`` and ``b`` must agree in rank, dtype and every dimension size; these
    are checked with ``assert``. The output copies ``a``'s shape and dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    for a_dim, b_dim in zip(a.shape, b.shape):
        assert a_dim == b_dim
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07003952
3953 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor for an elementwise unary op.

    The output keeps ``a``'s shape. Its dtype is ``a.dtype`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``. In that case a dtype
    different from ``a.dtype`` is drawn with ``rng`` from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    # Any of these other than the input dtype is a deliberately wrong output.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07003971
3972 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output tensor for SELECT(cond, a, b).

    The output shape follows ``cond``. When no ERROR_IF case is being
    generated, a size-1 dimension of ``cond`` becomes the largest of the
    three inputs' sizes on that axis. The dtype is ``a.dtype``, or a random
    different dtype for ``ErrorIf.WrongOutputType``.

    Args:
        ser: serializer; its ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only drawn from for ``ErrorIf.WrongOutputType``.
        cond, a, b: input tensors. Ranks must match unless ``error_name`` is
            ``ErrorIf.RankMismatch``; ``a`` and ``b`` dtypes must match.
        error_name: the ERROR_IF case being generated, or None.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    out_shape = [
        max(c_dim, a.shape[idx], b.shape[idx])
        if (c_dim == 1 and broadcasting)
        else c_dim
        for idx, c_dim in enumerate(cond.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004001
4002 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the BOOL output tensor for an elementwise comparison op.

    Broadcasting: a size-1 dimension of ``a`` takes ``b``'s size on that
    axis whenever ``b`` has that axis. Other dimensions come from ``a``.
    For ``ErrorIf.WrongOutputType`` a non-BOOL dtype is drawn with ``rng``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so each one is a wrong output dtype.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004031
4032 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor for a REDUCE_* op over ``axis``.

    Normally the reduced axis becomes size 1 and the dtype is ``a.dtype``.
    ERROR_IF variants:
      * AxisSmallerZero / AxisLargerRank: the shape is left as ``a.shape``.
      * ShapeOfAxisNotOne: the reduced axis keeps ``a``'s size, or gets a
        random size in [2, 10) if that size was already 1.
      * WrongOutputType: a dtype other than ``a.dtype`` is drawn with ``rng``.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()

    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004060
4061 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the INT32 output tensor for ARGMAX over ``axis``.

    Normally ``axis`` is removed from ``a``'s shape; the axis-range ERROR_IF
    cases leave the shape unchanged. Further ERROR_IF variants:
      * ArgmaxOutputRankMismatch: randomly drop the leading dim (only if more
        than one dim remains), otherwise append a trailing 1.
      * ArgmaxOutputShapeMismatch: grow every dim by a random 1..9.
      * WrongOutputType: pick a random non-INT32 dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()
    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # The coin flip is always drawn, even if only one dim remains.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004094
4095 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor for CONV2D.

    Layouts: IFM NHWC, filter OHWI, OFM NHWC. Each spatial output size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.
    The output channel count is ``filter.shape[0]``.

    Args:
        padding: [top, bottom, left, right].
        strides, dilations: [y, x].
        accum_dtype: dtype used for the output in the non-error case.
        error_name: ERROR_IF case being generated, or None.
            ConvOutputShapeMismatch grows H and/or W by a stride multiple;
            WrongInputType forces an INT32 output; WrongOutputType draws a
            dtype from ``usableDTypes`` excluding the expected one(s).

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    h = out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    w = out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by multiples of the stride so the bad shape does not also hit
        # the non-integer output error case.
        choices = [1, 2, 3]
        which = rng.choice(choices)  # 1: H only, 2: W only, 3: both
        if which in (1, 3):
            h += rng.choice(choices) * strides[0]
        if which in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # With a wrong input dtype, fall back to some potentially valid output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 -- presumably because either
        # is an acceptable accumulator for FP16 (confirm against the spec).
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004146
4147 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor for CONV3D.

    Layouts: IFM NDHWC, filter ODHWI, OFM NDHWC. For spatial axis ``i``
    (0=D, 1=H, 2=W) the output size is
    ``(ifm.shape[i+1] - 1 + padding[2i] + padding[2i+1]
       - (filter.shape[i+1] - 1) * dilations[i]) // strides[i] + 1``.
    The output channel count is ``filter.shape[0]``.

    Args:
        accum_dtype: dtype used for the output in the non-error case.
        error_name: ERROR_IF case being generated, or None.
            ConvOutputShapeMismatch grows D/H/W by stride multiples;
            WrongInputType forces an INT32 output; WrongOutputType draws a
            dtype from ``usableDTypes`` excluding the expected one(s).

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    spatial = [
        out_size(
            ifm.shape[i + 1],
            padding[2 * i],
            padding[2 * i + 1],
            filter.shape[i + 1],
            dilations[i],
            strides[i],
        )
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by multiples of the stride so the bad shape does not also hit
        # the non-integer output error case.
        choices = [1, 2, 3, 4]
        which = rng.choice(choices)  # 1: D, 2: H, 3: W, 4: all three
        for i in range(3):
            if which in (i + 1, 4):
                spatial[i] += rng.choice(choices) * strides[i]

    d, h, w = spatial
    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    # With a wrong input dtype, fall back to some potentially valid output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 -- presumably because either
        # is an acceptable accumulator for FP16 (confirm against the spec).
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4208
4209 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor for DEPTHWISE_CONV2D.

    Layouts: IFM NHWC, filter HWCM, OFM NHW(C*M). Kernel height and width
    come from ``filter.shape[0]`` and ``filter.shape[1]``. Output channels
    are ``filter.shape[2] * filter.shape[3]``. Each spatial size is
    ``(in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1``.

    Args:
        padding: [top, bottom, left, right].
        strides, dilations: [y, x].
        accum_dtype: dtype used for the output in the non-error case.
        error_name: ERROR_IF case being generated, or None.
            ConvOutputShapeMismatch grows H and/or W by a stride multiple;
            WrongInputType forces an INT32 output; WrongOutputType draws a
            dtype from ``usableDTypes`` excluding the expected one(s).

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """

    def out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        span = in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        return span // stride + 1

    kernel_h, kernel_w, in_channels, multiplier = (
        filter.shape[0],
        filter.shape[1],
        filter.shape[2],
        filter.shape[3],
    )

    h = out_size(ifm.shape[1], padding[0], padding[1], kernel_h, dilations[0], strides[0])
    w = out_size(ifm.shape[2], padding[2], padding[3], kernel_w, dilations[1], strides[1])

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by multiples of the stride so the bad shape does not also hit
        # the non-integer output error case.
        choices = [1, 2, 3]
        which = rng.choice(choices)  # 1: H only, 2: W only, 3: both
        if which in (1, 3):
            h += rng.choice(choices) * strides[0]
        if which in (2, 3):
            w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, in_channels * multiplier]

    # With a wrong input dtype, fall back to some potentially valid output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 input excludes both FP16 and FP32 -- presumably because either
        # is an acceptable accumulator for FP16 (confirm against the spec).
        excludes = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004259
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add and return the output tensor of a 2D pooling operator (NHWC).

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input tensor (NHWC); batch and channels are passed through.
        kernel: ``kernel[0]`` applies to H, ``kernel[1]`` to W.
        stride: ``stride[0]`` applies to H, ``stride[1]`` to W.
        pad: ``pad[0:2]`` pads H, ``pad[2:4]`` pads W.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    geometry_invalid = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if geometry_invalid:
        # The test is invalid anyway; use 1x1 spatial dims as a legal placeholder.
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        which = rng.choice(choices)
        # Grow by whole multiples of the stride so that the mismatch does not
        # also trigger the non-integer output-size ERROR_IF.
        if which in (1, 3):
            out_h += rng.choice(choices) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([ifm.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004297
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add and return the output tensor of a FULLY_CONNECTED operator.

    Shapes: input is ``[N, IC]``, filter is ``[OC, IC]``, output is ``[N, OC]``.

    The output dtype is always ``accum_dtype``. It is validated (or
    deliberately invalidated for ERROR_IF tests) during argument generation,
    so ``rng`` and ``error_name`` are not used here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004310
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add and return the output tensor of a MATMUL operator.

    Shapes: a is ``[N, H, C]``, b is ``[N, C, W]``, output is ``[N, H, W]``.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF dtype selection.
        a: left operand; its dtype picks the set of wrong output dtypes.
        b: right operand.
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: optional ``ErrorIf`` case to provoke.
    """
    out_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Output dtypes that are invalid for the given input dtype.
        if a.dtype == DType.INT8:
            wrong_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            wrong_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            wrong_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        out_dtype = rng.choice(a=wrong_types)
    elif error_name == ErrorIf.WrongInputType:
        # The input dtype is already invalid; any plausible output dtype will do.
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004358
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add and return the output tensor of a CONCAT operator.

    The output shape starts as a copy of the first input's shape. The other
    inputs' extents along ``axis`` are then added, unless the ERROR_IF case
    makes that impossible (rank mismatch or an out-of-range axis).

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        axis: concatenation axis.
        *a: input tensors; ``a[0]`` supplies the base shape and dtype.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    first = a[0]
    others = a[1:]

    out_shape = first.shape.copy()
    cannot_sum = error_name == ErrorIf.ConcatInputRankMismatch or error_name in [
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not cannot_sum:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        # Make the output extent disagree with the sum of the input extents.
        out_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        wrong_dtypes = list(candidate_dtypes - set([first.dtype]))
        out_dtype = rng.choice(wrong_dtypes)
    else:
        out_dtype = first.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004394
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add and return the output tensor of a PAD operator.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        a: input tensor.
        padding: per-dimension ``(before, after)`` pad amounts.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    out_shape = a.shape.copy()
    for dim in range(len(out_shape)):
        pad_before = padding[dim][0]
        pad_after = padding[dim][1]
        out_shape[dim] = pad_before + pad_after + out_shape[dim]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dim so it no longer matches the padding.
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] -= rng.choice([1, 2])

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004423
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add and return the output tensor of a RESHAPE operator.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        a: input tensor; supplies the output dtype for valid tests.
        shape: requested output shape; copied, never modified.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dim so the element count no longer matches the input.
        for dim in range(len(out_shape)):
            out_shape[dim] += rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004448
@staticmethod
def sliceOp(ser, rng, a, start, size, error_name=None):
    """Add and return the output tensor of a SLICE operator.

    The output shape is a copy of ``size``; ``start`` is not used here.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        a: input tensor; supplies the output dtype for valid tests.
        start: slice start (unused by this function).
        size: slice size, used as the output shape.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim in range(len(out_shape)):
            # Only grow small dims, so that they never drop to zero or below.
            if out_shape[dim] <= 2:
                delta = rng.choice([1, 2])
            else:
                delta = rng.choice([-2, -1, 1, 2])
            out_shape[dim] = out_shape[dim] + delta

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004480
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add and return the output tensor of a TILE operator.

    Each output dim is ``a.shape[i] * multiples[i]``.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF dtype selection.
        a: input tensor.
        multiples: per-dimension repeat counts; must match ``a``'s rank.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for dim, repeat in enumerate(multiples):
        out_shape[dim] = a.shape[dim] * repeat

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004506
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add and return the output tensor of a TRANSPOSE operator.

    Output dim ``i`` is ``a.shape[perms[i]]``. For the IndexOutsideBounds
    ERROR_IF, every output dim is ``a.shape[0]`` and ``perms`` is not indexed.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF dtype selection.
        a: input tensor.
        perms: permutation, one entry per input dim.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    out_shape = a.shape.copy()

    assert len(perms) == len(out_shape)

    index_error = error_name == ErrorIf.IndexOutsideBounds
    for dim in range(len(out_shape)):
        source_dim = 0 if index_error else perms[dim]
        out_shape[dim] = a.shape[source_dim]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004536
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add and return the output tensor of a GATHER operator.

    The output shape is ``[values.shape[0], indices.shape[1], values.shape[2]]``.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF dtype selection.
        values: rank-3 values tensor (rank not checked for WrongRank).
        indices: rank-2 indices tensor sharing ``values``' first dim.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    # A WrongRank test deliberately supplies values of another rank.
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([values.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004562
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add and return the output tensor of a SCATTER operator.

    The output has the same shape as ``values_in``. The ``values_in.shape``
    object itself is passed to ``addOutput``; it is not copied.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF dtype selection.
        values_in: rank-3 tensor being scattered into (rank not checked for WrongRank).
        indices: rank-2 indices tensor.
        input: rank-3 tensor of values to scatter.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    # A WrongRank test deliberately supplies values_in of another rank.
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values_in.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(candidate_dtypes) - set([values_in.dtype]))
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004591
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add and return the output tensor of a TABLE operator.

    The output has the input's shape. Its dtype is INT32 for INT16 input and
    INT8 otherwise, unless a wrong output dtype is requested via ERROR_IF.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF dtype selection.
        input: input tensor, expected to be INT8 or INT16.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    # A WrongInputType test deliberately supplies another input dtype.
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    if input.dtype == DType.INT16:
        out_dtype = DType.INT32
    else:
        out_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        candidates.remove(out_dtype)
        out_dtype = rng.choice(candidates)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004611
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add and return the output tensor of a RESIZE operator.

    The output height/width are derived from ``scale``, ``offset`` and
    ``border``. Batch and channel dims come from ``input`` unless an ERROR_IF
    case deliberately perturbs them.

    Args:
        serializer: serializer; the value returned by
            ``serializer.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        input: input tensor; ``shape[1]`` and ``shape[2]`` drive OH and OW.
        mode: not used by this function.
        scale: ``[scale_y_n, scale_y_d, scale_x_n, scale_x_d]``.
        offset: ``offset[0]`` is used for OH, ``offset[1]`` for OW.
        border: ``border[0]`` is used for OH, ``border[1]`` for OW.
        input_dtype: not used by this function.
        output_dtype: dtype of the created output tensor.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    # Calculate OH, OW
    scale_y_n = scale[0]
    scale_y_d = scale[1]
    scale_x_n = scale[2]
    scale_x_d = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp each scale term to >= 1 so the shape calculation below
        # (which divides by the denominators) stays well defined.
        scale_y_n = max(scale_y_n, 1)
        scale_y_d = max(scale_y_d, 1)
        scale_x_n = max(scale_x_n, 1)
        scale_x_d = max(scale_x_d, 1)

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh = max(oh, 1)
        ow = max(ow, 1)
        # Keep dims below the limit, except when the limit itself is the
        # error under test.
        if error_name != ErrorIf.MaxDimExceeded:
            oh = min(oh, MAX_RESIZE_DIMENSION - 1)
            ow = min(ow, MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        # If growing would reach MAX_RESIZE_DIMENSION, shrink by the same
        # amount instead.
        if change in [1, 3]:
            if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += scale_y_d
        if change in [2, 3]:
            if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += scale_x_d

    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim uses input.shape[0] rather than
        # input.shape[3] -- presumably because a wrong-rank input may have
        # no fourth dim; TODO confirm.
        output_dims = [
            input.shape[0],
            oh,
            ow,
            input.shape[0],
        ]
    elif error_name == ErrorIf.BatchMismatch:
        # Batch deliberately differs from the input's batch.
        output_dims = [
            input.shape[0] + rng.integers(1, 10),
            oh,
            ow,
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        # Channels deliberately differ from the input's channels.
        output_dims = [
            input.shape[0],
            oh,
            ow,
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004690
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add and return an output tensor shaped like ``val`` with ``out_dtype``.

    The caller chooses the dtype, so ``rng`` and ``error_name`` are unused.
    The ``val.shape`` object itself is passed to ``addOutput``.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004694
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add and return the output tensor of a TRANSPOSE_CONV2D operator.

    Args:
        ser: serializer; the value returned by ``ser.addOutput`` is returned.
        rng: random generator, used only for ERROR_IF perturbations.
        ifm: input feature map; its dtype matters for WrongOutputType.
        output_shape: requested output shape. NOTE: for the
            ConvOutputShapeMismatch ERROR_IF this list is modified in place,
            so the caller sees the perturbed shape too.
        accum_dtype: output dtype used when no dtype ERROR_IF is requested.
        error_name: optional ``ErrorIf`` case to provoke.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        which = rng.choice(choices)
        if which in (1, 3):
            output_shape[1] += rng.choice(choices)
        if which in (2, 3):
            output_shape[2] += rng.choice(choices)

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is already invalid; any plausible output dtype will do.
        out_dtype = DType.INT32
    elif error_name == ErrorIf.WrongOutputType:
        # FP16 inputs may legally produce FP16 or FP32, so exclude both.
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [accum_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))
    else:
        out_dtype = accum_dtype

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004719 return ser.addOutput(output_shape, out_dtype)