blob: 2b762aae1b248196a55d408897d5ecbc962f98bd [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Jeremy Johnson05c711e2022-12-12 18:00:41 +000017from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010018from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010020from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000021from tosa.DType import DType
22from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010023
24
Eric Kunzee5e26762020-10-13 16:11:07 -070025class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010026 # Maximum rank of tensor supported by test generator.
27 TOSA_TENSOR_MAX_RANK = 6
28
def __init__(self, args):
    """Initialise the generator from the parsed command line arguments.

    Stores the arguments, seeds the random number generator, builds the
    operator lists and records the floating point value range used when
    creating random tensor data.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    # No serializer exists until createSerializer() has been called
    self.ser = None
    self.rng = np.random.default_rng(self.random_seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape() returns this shape instead of a random one
    self.targetted_shape = None
    # The range may be supplied in either order, so sort out low/high here
    fp_value_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_value_range)
    self.random_fp_high = max(fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070043
def createSerializer(self, opName, testPath):
    """Create the output directory for a test and a new serializer for it.

    The test path (relative to the base output path) is remembered in
    self.testPath so that serialize() writes to the same location.
    """
    relative_dir = os.path.join(opName, testPath)
    self.testPath = relative_dir

    output_dir = os.path.join(self.basePath, relative_dir)
    os.makedirs(output_dir, exist_ok=True)
    self.ser = ts.TosaSerializer(
        output_dir, saveConstsToFile=self.args.dump_consts
    )
Eric Kunzee5e26762020-10-13 16:11:07 -070050
def getSerializer(self):
    """Return the current serializer (None until createSerializer is called)."""
    return self.ser
53
def serialize(self, testName):
    """Write the serialized graph and its JSON description to the test dir.

    Produces "<testName>.tosa" (binary) and "desc.json" (text) inside the
    directory selected by the last createSerializer() call.
    """
    test_dir = os.path.join(self.basePath, self.testPath)
    tosa_filename = "{}.tosa".format(testName)

    with open(os.path.join(test_dir, tosa_filename), "wb") as fd:
        fd.write(self.ser.serialize())

    with open(os.path.join(test_dir, "desc.json"), "w") as fd:
        fd.write(self.ser.writeJson(tosa_filename))
Eric Kunzee5e26762020-10-13 16:11:07 -070062
def resetRNG(self, seed=None):
    """Replace the random number generator with a freshly seeded one.

    When no seed is given, the original random seed plus one is used.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
67
def getRandTensor(self, shape, dtype):
    """Return an array of random values of the given shape for a TOSA dtype.

    Integer types narrower than 32 bits are returned as int32 arrays,
    INT48 as int64, and BF16 as float32 values that have been passed
    through vect_f32_to_bf16.  Floating point values are drawn from the
    configured random_fp_low/random_fp_high range.

    Raises Exception for a dtype that is not handled here.
    """

    def random_ints(low, high):
        # `high` is exclusive
        return self.rng.integers(low=low, high=high, size=shape)

    def random_floats():
        return self.rng.uniform(
            low=self.random_fp_low, high=self.random_fp_high, size=shape
        )

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    if dtype == DType.INT4:
        # TOSA specific INT4 weight range from -7 to 7
        return np.int32(random_ints(-7, 8))
    if dtype == DType.INT8:
        return np.int32(random_ints(-128, 128))
    if dtype == DType.UINT8:
        return np.int32(random_ints(0, 256))
    if dtype == DType.INT16:
        return np.int32(random_ints(-32768, 32768))
    if dtype == DType.UINT16:
        return np.int32(random_ints(0, 65536))
    if dtype == DType.INT32:
        return np.int32(random_ints(-(1 << 31), (1 << 31)))
    if dtype == DType.INT48:
        return np.int64(random_ints(-(1 << 47), (1 << 47)))
    if dtype == DType.FP16:
        return np.float16(random_floats())
    if dtype == DType.BF16:
        f32_values = np.float32(random_floats())
        # Floor the last 16 bits of each f32 value
        return np.float32(vect_f32_to_bf16(f32_values))
    if dtype == DType.FP32:
        return np.float32(random_floats())

    raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700112
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per shape/dtype pair.

    Returns the list of objects returned by the serializer's
    addPlaceholder(), in the same order as shape_list.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for position, shape in enumerate(shape_list):
        dtype = dtype_list[position]
        values = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, values))

    return tensors
123
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant per shape/dtype pair.

    Returns the list of objects returned by the serializer's addConst(),
    in the same order as shape_list.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for position, shape in enumerate(shape_list):
        dtype = dtype_list[position]
        values = self.getRandTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, values))

    return tensors
134
def makeShape(self, rank):
    """Return an int32 shape array of the requested rank.

    A target shape set via setTargetShape() takes precedence (the rank is
    then ignored); otherwise each dimension is drawn at random from
    args.tensor_shape_range (upper bound exclusive).
    """
    if not self.targetted_shape:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(
            low=shape_range[0], high=shape_range[1], size=rank
        )
        return np.int32(dims)

    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700145
def setTargetShape(self, shape):
    """Force makeShape() to return `shape`; pass None to restore random shapes."""
    self.targetted_shape = shape
148
def randInt(self, low=0, high=256):
    """Return one random int32 in the range [low, high)."""
    drawn = self.rng.integers(low=low, high=high, size=1)
    return np.int32(drawn)[0]
151
def getRandNumberDType(self, dtype):
    """Return a single random value suitable for the given TOSA dtype.

    Floating point values are drawn from the configured
    random_fp_low/random_fp_high range (BF16 values are passed through
    vect_f32_to_bf16).  BOOL returns a random choice of False/True.
    Integer types return an int32 scalar, except INT48 which returns an
    int64 scalar.

    UINT8 and UINT16 are supported with the same value ranges that
    getRandTensor uses, so both generators accept the same integer types.

    Raises Exception for a dtype that is not handled here.
    """
    if dtype == DType.FP32:
        return np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.FP16:
        return np.float16(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.BF16:
        rand_f32 = np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
        return vect_f32_to_bf16(rand_f32)
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: does not fit in 32 bits
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    # All remaining integer types fit in an int32 (`high` is exclusive)
    return np.int32(self.rng.integers(low, high, size=1))[0]
185
def shapeStr(self, shape):
    """Return the dimensions of `shape` joined with "x", e.g. "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700194
def typeStr(self, dtype):
    """Return the string name of a dtype, or of a list/tuple of dtypes.

    For a list or tuple (which must have at least two entries) the names
    of the first two entries are joined with "x"; every entry is still
    converted, so an unknown dtype anywhere in the sequence raises.

    Raises Exception for a dtype that is not in DTYPE_ATTRIBUTES.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        names = [self.typeStr(entry) for entry in dtype]
        # Limit types to the first 2 as the 3rd is the accumulator
        return "x".join(names[:2])

    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(
            "Unknown dtype, cannot convert to string: {}".format(dtype)
        )
    return DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
def typeWidth(self, dtype):
    """Return the "width" entry recorded for `dtype` in DTYPE_ATTRIBUTES.

    Raises Exception for a dtype that is not in DTYPE_ATTRIBUTES.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700215
def constrictBatchSize(self, shape):
    """Clamp the first dimension of `shape` to args.max_batch_size.

    The clamp is skipped when no maximum is configured or when explicit
    target shapes were requested.  `shape` is modified in place and also
    returned.
    """
    batch_limit = self.args.max_batch_size
    if batch_limit and not self.args.target_shapes:
        shape[0] = min(shape[0], batch_limit)
    return shape
221
Eric Kunzee5e26762020-10-13 16:11:07 -0700222 # Argument generators
223 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
224 # Where the string descriptor is used to generate the test name and
225 # The build_fcn_arg_list is expanded and passed to the operator test
226 # build function
227
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator to the graph and return its output tensor.

    `op` is either a plain operator value (an int) or an operator
    definition dictionary with "op" and "operands" entries.  `qinfo`, when
    given, is indexed as a pair whose two entries are passed to the NEGATE
    attribute; it is not used by the other operators handled here.

    Returns None when the ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # build_placeholder passes an int, ABS/other ops pass a dictionary
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # qinfo defaults to None; fall back to zero for both entries (as
        # build_pool2d does) instead of failing on the subscript below.
        # Done after validation so the validators still see the caller's
        # original qinfo.
        if qinfo is None:
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
278
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a two-input broadcasting operator and return its output tensor.

    Returns None when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    input_names = [a.name, b.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result
311
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input non-broadcasting operator and return its output tensor.

    validator_fcns and error_name are accepted but not used here.
    """
    result = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    operand_names = [a.name, b.name]
    self.ser.addOperator(op["op"], operand_names, [result.name])
    return result
316
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift operator and return its output tensor.

    `round` is passed straight to the operator's attribute.  Returns None
    when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    input_names = [a.name, b.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op["op"], input_names, output_names, shift_attr)
    return result
354
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply operator and return its output tensor.

    `shift` is passed straight to the operator's attribute.  Returns None
    when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)

    # Special for multiply:
    # Force the result to INT32 for INT types
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        # Deliberately pick an output type other than INT32
        wrong_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result.setDtype(wrong_dtype)

    input_names = [a.name, b.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op["op"], input_names, output_names, mul_attr)
    return result
399
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table lookup operator and return its output tensor.

    `table` is passed straight to the operator's attribute.  Returns None
    when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    input_names = [a.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names, table_attr)
    return result
433
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select operator (cond, a, b) and return its output tensor.

    Returns None when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    input_names = [cond.name, a.name, b.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result
470
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input comparison operator and return its output tensor.

    Returns None when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)

    input_names = [a.name, b.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return result
509
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax operator over `axis` and return its output tensor.

    Returns None when the ERROR_IF validation rejects the test.
    """
    result = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    input_names = [a.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_names, output_names, axis_attr)
    return result
544
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator and return its output tensor.

    kernel, stride, pad and accum_dtype are passed to the pooling
    attribute together with the two qinfo entries (zero for both when no
    qinfo is supplied).  Returns None when the ERROR_IF validation rejects
    the test.
    """
    result = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and input.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, input.dtype),
            TosaQuantGen.getZeroPoint(self, result.dtype),
        ]

    input_names = [input.name]
    output_names = [result.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_names, output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result],
        input_list=input_names,
        output_list=output_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    # The validators above see the caller's qinfo; only the attribute
    # gets the zero fallback
    if qinfo is None:
        qinfo = [0, 0]

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_names, output_names, pool_attr)
    return result
606
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a max pooling operator via build_pool2d.

    Max pooling has no accumulator type, so DType.UNKNOWN is supplied for
    build_pool2d's accum_dtype argument.
    """
    return self.build_pool2d(
        op,
        input,
        DType.UNKNOWN,
        stride,
        pad,
        kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
631
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution operator and return its output tensor.

    `padding` must have exactly four entries.  padding, strides and
    dilations are passed to the convolution attribute together with the
    two qinfo entries (zero for both when no qinfo is supplied).

    Returns None when the ERROR_IF validation rejects the test.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # qinfo defaults to None; fall back to zero for both entries (as
    # build_pool2d does) instead of failing on the subscript below.  Done
    # after validation so the validators still see the caller's qinfo.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
703
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add the operator ``op["op"]`` with a ConvAttribute to the serializer.

    The output tensor is produced by ``OutputShaper.conv3dOp``. ``padding``
    must hold exactly six values. ``qinfo`` is read as ``qinfo[0]`` and
    ``qinfo[1]`` when the attribute is built, so it has to be subscriptable
    by then (either supplied by the caller or regenerated below).

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType test on a non-8-bit input, regenerate the zero
    # points from the actual input/output dtypes.
    eight_bit_types = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_types:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, dtype)
            for dtype in (ifm.dtype, out_tensor.dtype)
        ]

    # Operand name lists; the ERROR_IF arg generator may hand back altered
    # lists for error_if checks.
    num_operands = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [out_tensor.name],
    )

    validator_args = {
        "op": op,
        "input_dtype": ifm.dtype,
        "weight_dtype": filter.dtype,
        "output_dtype": out_tensor.dtype,
        "qinfo": qinfo,
        "input_list": in_names,
        "num_operands": num_operands,
        "output_list": out_names,
        "pad": padding,
        "stride": strides,
        "dilation": dilations,
        "input_shape": ifm.shape,
        "weight_shape": filter.shape,
        "output_shape": out_tensor.shape,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
775
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add the operator ``op["op"]`` with a TransposeConvAttribute.

    The output tensor is produced by ``OutputShaper.transposeConv2DOp``.
    ``out_pad`` must hold exactly four values. ``qinfo`` is read as
    ``qinfo[0]`` and ``qinfo[1]`` when the attribute is built.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # For a WrongInputType test on a non-8-bit input, regenerate the zero
    # points from the actual input/output dtypes.
    eight_bit_types = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_types:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, dtype)
            for dtype in (ifm.dtype, out_tensor.dtype)
        ]

    # Operand name lists; the ERROR_IF arg generator may hand back altered
    # lists for error_if checks.
    num_operands = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [out_tensor.name],
    )

    validator_args = {
        "op": op,
        "input_dtype": ifm.dtype,
        "weight_dtype": filter.dtype,
        "output_dtype": out_tensor.dtype,
        "qinfo": qinfo,
        "input_list": in_names,
        "num_operands": num_operands,
        "output_list": out_names,
        "pad": out_pad,
        "stride": stride,
        "input_shape": ifm.shape,
        "weight_shape": filter.shape,
        "output_shape": out_tensor.shape,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1]
    )

    self.ser.addOperator(op["op"], in_names, out_names, tconv_attr)
    return out_tensor
838
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add the operator ``op["op"]`` with a ConvAttribute to the serializer.

    The output tensor is produced by ``OutputShaper.depthwiseConv2dOp``.
    ``qinfo`` is read as ``qinfo[0]`` and ``qinfo[1]`` when the attribute is
    built.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For a WrongInputType test on a non-8-bit input, regenerate the zero
    # points from the actual input/output dtypes.
    eight_bit_types = (DType.INT8, DType.UINT8)
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in eight_bit_types:
        qinfo = [
            TosaQuantGen.getZeroPoint(self, dtype)
            for dtype in (ifm.dtype, out_tensor.dtype)
        ]

    # Operand name lists; the ERROR_IF arg generator may hand back altered
    # lists for error_if checks.
    num_operands = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [out_tensor.name],
    )

    validator_args = {
        "op": op,
        "input_dtype": ifm.dtype,
        "weight_dtype": filter.dtype,
        "output_dtype": out_tensor.dtype,
        "qinfo": qinfo,
        "input_list": in_names,
        "num_operands": num_operands,
        "output_list": out_names,
        "pad": padding,
        "stride": strides,
        "dilation": dilations,
        "input_shape": ifm.shape,
        "weight_shape": filter.shape,
        "output_shape": out_tensor.shape,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)
    return out_tensor
909
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add the operator ``op["op"]`` with a FullyConnectedAttribute.

    The output tensor is produced by ``OutputShaper.fullyConnectedOp``.
    ``qinfo`` is read as ``qinfo[0]`` and ``qinfo[1]`` when the attribute is
    built, so it has to be subscriptable.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [ifm.name, filter.name, bias.name],
        [out_tensor.name],
    )

    validator_args = dict(
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], in_names, out_names, fc_attr)
    return out_tensor
958
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Add the operator ``op["op"]`` with a MatMulAttribute to the serializer.

    The output tensor is produced by ``OutputShaper.matmulOp`` from the two
    inputs ``a`` and ``b``. ``qinfo`` is read as ``qinfo[0]`` and
    ``qinfo[1]`` when the attribute is built, so it has to be subscriptable.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], in_names, out_names, matmul_attr)
    return out_tensor
1000
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
    """Add the operator ``op["op"]`` with an AxisAttribute for ``axis``.

    The output tensor is produced by ``OutputShaper.reduceOp``. Note that
    ``validator_fcns`` has no default here.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1035
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with a ClampAttribute built from two
    random values of ``a.dtype``.

    Normally the smaller draw becomes ``min_val`` and the larger ``max_val``.
    For ``ErrorIf.MaxSmallerMin`` the two are deliberately swapped, after
    redrawing until the values differ.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    draws = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    want_inverted = error_name == ErrorIf.MaxSmallerMin
    if want_inverted:
        # Equal values could not trigger the error, so redraw until distinct.
        while draws[0] == draws[1]:
            draws = [self.getRandNumberDType(a.dtype) for _ in range(2)]

    smaller, larger = min(draws), max(draws)
    if want_inverted:
        min_val, max_val = larger, smaller
    else:
        min_val, max_val = smaller, larger

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    is_float_input = a.dtype in (DType.BF16, DType.FP16, DType.FP32)
    if not is_float_input:
        # Non-float input: bounds go in the first value pair, zeros in the second.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)

        # Float input: zeros in the first value pair, bounds in the second.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)
    return out_tensor
1091
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with a LeakyReluAttribute holding a
    random FP32 value.

    ``validator_fcns`` is accepted but not used by this builder; no ERROR_IF
    validation is run. Always returns the output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    random_fp32 = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(random_fp32)

    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names, leaky_attr)
    return out_tensor
1100
1101 # Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with one input and no attribute.

    ``validator_fcns`` is accepted but not used by this builder; no ERROR_IF
    validation is run. Always returns the output tensor.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    in_names = [a.name]
    out_names = [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1107
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Add the single-input operator ``op["op"]`` with no attribute.

    The output tensor is produced by ``OutputShaper.unaryOp``.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1138
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Add the single-input operator ``op["op"]`` with no attribute.

    The output tensor is produced by ``OutputShaper.unaryOp``.

    Returns the output tensor, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1169
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with an AxisAttribute over a variable
    number of input tensors.

    The positional arguments ``a`` are the input tensors followed by the
    axis as the final element. Unless the test is a WrongInputType error
    test, that final element is asserted to be an int.

    Returns the output tensor (from ``OutputShaper.concatOp``), or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    if error_name != ErrorIf.WrongInputType:
        # isinstance is the idiomatic type check (exact type comparison
        # with == is flagged by linters).
        assert isinstance(a[-1], int)

    # To store variable length list of input tensors we need to store axis along with it
    axis = a[-1]
    a = a[:-1]

    result_tens = OutputShaper.concatOp(
        self.ser, self.rng, axis, *a, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in a]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        # Shape/dtype checks are keyed on the first input; the full tuple
        # of inputs is also passed via ``inputs``.
        input_shape=a[0].shape,
        output_shape=result_tens.shape,
        input_dtype=a[0].dtype,
        output_dtype=result_tens.dtype,
        inputs=a,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001218
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add the operator ``op["op"]`` with a PadAttribute to the serializer.

    ``padding`` is flattened with ``padding.flatten()`` for the attribute
    and passed unflattened to the validators. The attribute is built before
    validation runs.

    Returns the output tensor (from ``OutputShaper.padOp``), or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(
        self.ser.builder, flat_padding, pad_const_int, pad_const_float
    )

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001266
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with a ReshapeAttribute for ``newShape``.

    Returns the output tensor (from ``OutputShaper.reshapeOp``), or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.reshapeOp(
        self.ser, self.rng, a, newShape, error_name
    )

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], in_names, out_names, reshape_attr)
    return out_tensor
1302
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with an AxisAttribute for ``axis``.

    Returns the output tensor (from ``OutputShaper.unaryOp``), or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return out_tensor
1337
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with a TransposeAttribute for ``perms``.

    The attribute is built before validation runs.

    Returns the output tensor (from ``OutputShaper.transposeOp``), or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.transposeOp(
        self.ser, self.rng, a, perms, error_name
    )

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return out_tensor
1372
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Add the operator ``op["op"]`` with a SliceAttribute for ``start`` and
    ``size``.

    Returns the output tensor (from ``OutputShaper.sliceOp``), or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tensor = OutputShaper.sliceOp(
        self.ser, self.rng, a, start, size, error_name
    )

    # op["operands"] unpacks into exactly two counts, which are summed.
    count_a, count_b = op["operands"]
    operand_total = count_a + count_b

    # The ERROR_IF arg generator may hand back altered name lists for
    # error_if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    validator_args = dict(
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validator_args
    )
    if not is_valid:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out_tensor
1410
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a tile operator on ``a`` and return its output tensor.

    ``multiples`` is forwarded to the output shaper and stored in the tile
    attribute.

    Returns the tensor created by ``OutputShaper.tileOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # Invalidate Input/Output list for error if checks.
    in_names = [a.name]
    out_names = [out_tens.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    check_args = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": out_tens.shape,
        "input_dtype": a.dtype,
        "output_dtype": out_tens.dtype,
        "result_tensors": [out_tens],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": operand_total,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not is_valid:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out_tens
1444
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a gather operator on ``values`` and return its output tensor.

    A fresh INT32 constant indices tensor of shape ``(values.shape[0], W)``
    is generated here, with every entry drawn from ``[0, values.shape[1])``
    so that it does not exceed the dimensions of ``values``.  ``W`` comes
    from ``self.randInt`` over ``self.args.tensor_shape_range``.

    Returns the tensor created by ``OutputShaper.gatherOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values tensor.
    index_limit = values.shape[1]  # K
    shape_range = self.args.tensor_shape_range
    index_count = self.randInt(shape_range[0], shape_range[1])  # W
    raw_indices = self.rng.integers(
        low=0, high=index_limit, size=[values.shape[0], index_count]
    )
    indices_arr = np.int32(raw_indices)  # (N, W)
    indices_tens = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    out_tens = OutputShaper.gatherOp(
        self.ser, self.rng, values, indices_tens, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [values.name, indices_tens.name]
    out_names = [out_tens.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    check_args = {
        "op": op,
        "input_shape": values.shape,
        "output_shape": out_tens.shape,
        "input_dtype": values.dtype,
        "output_dtype": out_tens.dtype,
        "result_tensors": [out_tens],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": operand_total,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tens
1491
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a scatter operator and return its output tensor.

    A fresh INT32 constant indices tensor of shape
    ``(values_in.shape[0], input.shape[1])`` is generated here, with every
    entry drawn from ``[0, values_in.shape[1])`` so that it does not exceed
    the dimensions of ``values_in``.

    Returns the tensor created by ``OutputShaper.scatterOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    # Create a new indices tensor here with data that doesn't exceed the
    # dimensions of the values_in tensor.
    index_limit = values_in.shape[1]  # K
    index_count = input.shape[1]  # W
    raw_indices = self.rng.integers(
        low=0, high=index_limit, size=[values_in.shape[0], index_count]
    )
    indices_arr = np.int32(raw_indices)  # (N, W)
    indices_tens = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    out_tens = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices_tens, input, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [values_in.name, indices_tens.name, input.name]
    out_names = [out_tens.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    check_args = {
        "op": op,
        "input_shape": values_in.shape,
        "output_shape": out_tens.shape,
        "input_dtype": values_in.dtype,
        "output_dtype": out_tens.dtype,
        "result_tensors": [out_tens],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": operand_total,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return out_tens
1536
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a resize operator on ``input`` and return its output tensor.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded to the
    output shaper and the ERROR_IF validators, and are stored in the resize
    attribute.  ``input_dtype`` and ``output_dtype`` are passed through to
    the shaper and validators as given.

    Returns the tensor created by ``OutputShaper.resizeOp``, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tens = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [input.name]
    out_names = [out_tens.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    check_args = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": out_tens.shape,
        "offset": offset,
        "border": border,
        "input_list": in_names,
        "output_list": out_names,
        "result_tensors": [out_tens],
        "num_operands": operand_total,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out_tens
1598
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize a two-input, two-output identity operator.

    One output tensor is created per input via ``OutputShaper.unaryOp``.
    ``validator_fcns`` is accepted for signature parity with the other
    builders but is not consulted here.

    Returns the output tensor paired with ``val``; the one paired with
    ``val2`` is registered with the serializer but not returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)

    # Every other builder hands the serializer op["op"], not the operator
    # definition itself.  Unwrap it here too, while still accepting a bare
    # opcode so existing callers keep working.
    opcode = op["op"] if isinstance(op, dict) else op

    self.ser.addOperator(
        opcode, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1606
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the serializer and return it.

    ``op``, ``validator_fcns`` and ``error_name`` are not used.
    """
    const_tens = val
    self.ser.addOutputTensor(const_tens)
    return const_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001610
1611 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize a cast of ``val`` to ``out_dtype`` and return the output tensor.

    Returns the tensor created by ``OutputShaper.typeConversionOp``, or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    out_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    in_names = [val.name]
    out_names = [out_tens.name]
    placeholder_count, const_count = op["operands"]
    operand_total = placeholder_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    check_args = {
        "op": op,
        "input_shape": val.shape,
        "output_shape": out_tens.shape,
        "input_dtype": val.dtype,
        "output_dtype": out_tens.dtype,
        "result_tensors": [out_tens],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": operand_total,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **check_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tens
1644
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize a rescale of ``val`` to ``out_dtype`` and return the output.

    Random zero points and per-channel (or single) scale factors are
    generated, converted to multiplier/shift pairs with
    ``TosaQuantGen.computeMultiplierAndShift`` and stored in the rescale
    attribute.  For ``scale32`` tests without an ERROR_IF, the placeholder
    data file of ``val`` is loaded, clamped so that ``value - input_zp``
    lies in ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]`` and rewritten
    if any value changed.

    Returns the tensor created by ``OutputShaper.typeConversionOp``, or None
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """

    def pick_zero_point(dtype, zero_point_errors):
        """Return ``(zero_point, extra_width_bits)`` for a tensor of ``dtype``."""
        if dtype == DType.INT8:
            return self.randInt(-128, 128), 1
        if dtype == DType.UINT8:
            return self.randInt(0, 256), 1
        if error_name in zero_point_errors:
            # A zero point is required to be invalid here, so never use 0.
            zero_point = self.randInt(-128, 128)
            if zero_point == 0:
                zero_point = zero_point + self.rng.integers(1, 10)
            return zero_point, 1
        if dtype == DType.UINT16:
            # Must come after the U16 zero-point ERROR_IF check above
            return self.rng.choice([0, 32768]), 1
        return 0, 0

    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension of val) or a single shared scale.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # NOTE(review): the width is bumped by one bit whenever a zero point may
    # be non-zero - presumably to cover the range of (value - zero_point);
    # confirm against the specification.
    input_zp, in_extra_bits = pick_zero_point(
        val.dtype,
        [ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid],
    )
    in_type_width += in_extra_bits

    output_zp, out_extra_bits = pick_zero_point(
        out_dtype,
        [ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid],
    )
    out_type_width += out_extra_bits

    # Calculate scale based on:
    # scale = a *(2^output_width)/(2^input_width))
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Shift with Python ints: shifting by a NumPy int32 scalar can keep
        # the result 32 bits wide (depending on NumPy's promotion rules and
        # the platform default int), which overflows for shifts >= 32.
        half_range = 1 << (int(shift_arr[i]) - 1)
        min_shift_value_arr[i] = -half_range
        max_shift_value_arr[i] = half_range - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        # Clamp in int64 relative to the zero point, then restore the
        # original offset and dtype.
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1799
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor for a cond_if test.

    The tensor is BOOL with an empty shape unless ``error_name`` requests a
    deliberately wrong dtype (``CondIfCondNotMatchingBool``) or a wrong
    shape (``CondIfCondShapeNotSizeOne``).  Its data is ``[cond]``.
    """
    wants_wrong_type = error_name == ErrorIf.CondIfCondNotMatchingBool
    cond_dtype = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if wants_wrong_type
        else DType.BOOL
    )

    if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
        # Randomly pick one of two shapes that are not of size one.
        cond_dims = [2] if self.rng.choice([1, 2]) == 1 else [1, 2]
    else:
        # Must be of size 1 (rank 0)
        cond_dims = []

    return self.ser.addConst(cond_dims, cond_dtype, [cond])
1816
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a cond_if whose then/else blocks each hold one constant.

    The supplied ``then_tens``/``else_tens`` are ignored except for the
    shape of ``then_tens``; each block is filled with a freshly generated
    random INT32 constant.  For the output-list mismatch ERROR_IF tests the
    affected block gets a constant with a deliberately different shape.

    Returns the INT32 output tensor of the cond_if, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    # Shape shared by the then/else constants and the result
    out_shape = then_tens.shape

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        wrong_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(wrong_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                # Only grow small dimensions
                delta = self.rng.choice([1, 2, 4])
            wrong_shape[idx] = dim + delta
        wrong_arr = np.int32(self.rng.integers(0, 256, size=wrong_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # And the result tensor based on any of the outputs
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Build the actual then/else tensors inside their blocks
    block_specs = (
        (then_block, ErrorIf.CondIfOutputListThenGraphMismatch, then_arr),
        (else_block, ErrorIf.CondIfOutputListElseGraphMismatch, else_arr),
    )
    for block_name, mismatch_error, block_arr in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(wrong_shape, DType.INT32, wrong_arr)
        else:
            block_tens = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return result_tens
1885
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Serialize a cond_if whose then/else blocks each apply a binary op.

    Both blocks take ``a`` and ``b`` as inputs.  For FP32/BF16/FP16/INT32
    the then block adds and the else block subtracts; for INT8/INT16 the
    blocks use logical right and left shift respectively.  For the
    input/output-list mismatch ERROR_IF tests the affected block is given
    an input or output with a deliberately different shape.

    Returns the output tensor of the cond_if, or None when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # NOTE(review): unlike build_cond_if_const, dimensions of 3 or less
        # can become zero or negative here - confirm that is acceptable.
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be called "op": that would shadow the
    # operator definition passed to the validator after the loop.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1968
Matthew Haddon630c17c2021-10-14 15:05:41 +01001969 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001970 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001971
Kevin Cheng550ccc52021-03-03 11:21:43 -08001972 cond_block = "COND_BLOCK"
1973 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001974
1975 attr = ts.TosaSerializerAttribute()
1976 attr.WhileLoopAttribute(cond_block, body_block)
1977
1978 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001979 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001980 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001981 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001982
1983 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001984 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1985 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01001986 if error_name == ErrorIf.InputListOutputListMismatch:
1987 incorrect_acc = deepcopy(acc)
1988 for i in range(len(incorrect_acc.shape)):
1989 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
1990 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
1991 else:
1992 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001993
1994 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001995 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001996 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08001997 [iter.name, a.name, acc.name],
1998 [iter_out.name, a_out.name, acc_out.name],
1999 attr,
2000 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002001 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002002
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002003 if error_name in [
2004 ErrorIf.InputListCondGraphMismatch,
2005 ErrorIf.InputListBodyGraphInputMismatch,
2006 ErrorIf.InputListBodyGraphOutputMismatch,
2007 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002008 incorrect_iter = deepcopy(iter)
2009 for i in range(len(incorrect_iter.shape)):
2010 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2011 if len(incorrect_iter.shape) == 0:
2012 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2013
2014 incorrect_acc = deepcopy(acc)
2015 for i in range(len(incorrect_acc.shape)):
2016 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2017
Eric Kunzee5e26762020-10-13 16:11:07 -07002018 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002019 self.ser.addBasicBlock(cond_block)
2020
Matthew Haddon630c17c2021-10-14 15:05:41 +01002021 if error_name == ErrorIf.InputListCondGraphMismatch:
2022 self.ser.addInputTensor(incorrect_iter)
2023 self.ser.addInputTensor(a)
2024 self.ser.addInputTensor(incorrect_acc)
2025 else:
2026 self.ser.addInputTensor(iter)
2027 self.ser.addInputTensor(a)
2028 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002029 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002030
2031 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002032 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002033 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002034 cond_type = DType.BOOL
2035 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2036 choice = self.rng.choice([1, 2])
2037 if choice == 1:
2038 cond_shape = [3]
2039 else:
2040 cond_shape = [1, 2]
2041 else:
2042 cond_shape = []
2043 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002044
Kevin Cheng550ccc52021-03-03 11:21:43 -08002045 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002046
2047 # BODY block (input: a, acc, iter, output: a, acc, iter)
2048 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002049 self.ser.addBasicBlock(body_block)
2050
Matthew Haddon630c17c2021-10-14 15:05:41 +01002051 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2052 self.ser.addInputTensor(incorrect_iter)
2053 self.ser.addInputTensor(a)
2054 self.ser.addInputTensor(incorrect_acc)
2055 else:
2056 self.ser.addInputTensor(iter)
2057 self.ser.addInputTensor(a)
2058 self.ser.addInputTensor(acc)
2059
Kevin Cheng550ccc52021-03-03 11:21:43 -08002060 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002061
2062 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002063 iter_body_out = self.ser.addIntermediate(
2064 incorrect_iter.shape, incorrect_iter.dtype
2065 )
2066 acc_body_out = self.ser.addIntermediate(
2067 incorrect_acc.shape, incorrect_acc.dtype
2068 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002069 else:
2070 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2071 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2072
Eric Kunzee5e26762020-10-13 16:11:07 -07002073 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2074 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2075 self.ser.addOutputTensor(iter_body_out)
2076 self.ser.addOutputTensor(a)
2077 self.ser.addOutputTensor(acc_body_out)
2078
Les Bell729b0352021-11-24 10:28:21 +00002079 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002080 self.ser,
2081 validator_fcns,
2082 error_name,
2083 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002084 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002085 ):
2086 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002087
Eric Kunzee5e26762020-10-13 16:11:07 -07002088 return acc_out
2089
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Add a two-input, FFT-attributed operator to the serializer.

    Args:
        op: operator entry; its "op" and "operands" fields are read here.
        val1, val2: the two input tensors (their names, and val1's shape
            and dtype, are forwarded to the ERROR_IF validation).
        inverse: value stored in the FFT attribute.
        validator_fcns: ERROR_IF validators forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: name of the ERROR_IF case being generated, or None.

    Returns:
        The result tensors produced by OutputShaper.fft2dOp, or None when
        evValidateErrorIfs returns a falsy value.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    # Snapshot the result metadata before the name lists are (possibly)
    # invalidated for the ERROR_IF case being generated.
    result_names = [tensor.name for tensor in results]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], result_names
    )

    validation_info = {
        "op": op,
        "inverse": inverse,
        "input1": val1,
        "input2": val2,
        "input_shape": val1.shape,
        "input_dtype": val1.dtype,
        "output_shape": result_shapes,
        "output_dtype": result_dtypes,
        "result_tensors": results,
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_info
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2131
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add a single-input operator with multiple results to the serializer.

    Args:
        op: operator entry; its "op" and "operands" fields are read here.
        val: the input tensor (name, shape and dtype are used).
        validator_fcns: ERROR_IF validators forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: name of the ERROR_IF case being generated, or None.

    Returns:
        The result tensors produced by OutputShaper.rfft2dOp, or None when
        evValidateErrorIfs returns a falsy value.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    # Snapshot the result metadata before the name lists are (possibly)
    # invalidated for the ERROR_IF case being generated.
    result_names = [tensor.name for tensor in results]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], result_names
    )

    validation_info = {
        "op": op,
        "input_shape": val.shape,
        "input_dtype": val.dtype,
        "output_shape": result_shapes,
        "output_dtype": result_dtypes,
        "result_tensors": results,
        "input_list": input_names,
        "output_list": output_names,
        "num_operands": placeholder_count + const_count,
    }
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_info
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return results
2165
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested filters with what the operator allows.

    Args:
        op: operator entry; its "rank" (min, max) and "types" are read.
        shapeFilter: list of shapes, or a falsy value for "no shape filter".
        rankFilter: iterable of requested ranks, or None.
        dtypeFilter: container of requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; required for negative tests.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" entries,
        or None for a negative test without a validator (or for any other
        testType value).
    """
    # Default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    shapeFilter = shapeFilter or [None]

    # Work out the ranks: requested ones clipped to what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Nothing requested at all: bound the default by the default range
        # or by the operator, whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    # Work out the dtypes: operator dtypes filtered by the requested ones.
    # A list-valued dtype entry is matched on its first element.
    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    # Negative test: the validator's parameter requirements take priority
    error_arguments = validator(check=False, op=op)["param_reqs"]

    required_rank = error_arguments["rank"]
    required_dtype = error_arguments["dtype"]
    required_shape = error_arguments["shape"]

    return {
        # Without a required shape, reduce the number of shapes to keep
        # test numbers small
        "shapeFilter": (
            required_shape if required_shape is not None else shapeFilter[:2]
        ),
        "rankFilter": (
            required_rank if required_rank is not None else cleanRankFilter
        ),
        "dtypeFilter": (
            required_dtype if required_dtype is not None else cleanDtypeFilter
        ),
    }
2246
def genOpTestList(
    self,
    opName,
    shapeFilter=[None],
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests for one operator.

    Args:
        opName: key into self.TOSA_OP_LIST.
        shapeFilter: list of shapes to test; the default list is only
            read here, never mutated.
        rankFilter: iterable of ranks to test, or None.
        dtypeFilter: container of dtypes to test, or None.
        testType: "positive" or "negative".

    Returns:
        A list of tuples
        (opName, testStr, dtype, error_name, shapeList, args).
        An empty list is returned when create_filter_lists yields no
        filters (e.g. a negative test for an op without validators).

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    # Only the tensor-shape and argument generators are needed here; the
    # unpacking still checks that build_fcn holds four entries.
    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consists of tuples of the (str, [])
                    # string representation and the build function
                    # argument list
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to a
        # ERROR_IF statement. A test is dropped as soon as ANY validator
        # flags it; previously only the last validator's verdict counted.
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2353
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test for opName and serialize it when it is valid.

    Args:
        opName: key into self.TOSA_OP_LIST.
        testStr: test name, passed to createSerializer.
        dtype_or_dtypeList: a single dtype, or a list with one dtype per
            operand.
        error_name: name of the ERROR_IF case being generated, or None.
        shapeList: one shape per operand (CONCAT is exempt from the
            operand-count check).
        testArgs: extra positional arguments for the build function.

    Raises:
        Exception: if opName is not in self.TOSA_OP_LIST.
        TypeError: re-raised from the build function after printing the
            call details.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    is_concat = op["op"] == Op.CONCAT

    # Expand a single dtype to one entry per operand (per shape for CONCAT)
    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    # Quantization info is only generated for ops that declare a "qgen"
    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            # Without validators, qinfo (if any) is passed positionally
            trailing = () if qinfo is None else (qinfo,)
            resultName = build_fcn(self, op, *tens, *testArgs, *trailing)
        else:
            keywords = {
                "validator_fcns": error_if_validators,
                "error_name": error_name,
            }
            if qinfo is not None:
                keywords["qinfo"] = qinfo
            resultName = build_fcn(self, op, *tens, *testArgs, **keywords)
    except TypeError as e:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise e

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002444
def createDynamicOpLists(self):
    """Expand the *_TEMPLATE convolution entries into per-kernel ops.

    Each template in self.TOSA_OP_LIST is shallow-copied once per kernel
    size, given that kernel as its "filter" and marked as not a template.
    Afterwards every entry whose "template" field is truthy is removed.
    Does nothing when "conv2d_TEMPLATE" is no longer present.
    """
    op_list = self.TOSA_OP_LIST

    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    def add_from_template(name, template_key, kernel):
        # Shallow copy: only the per-kernel fields are replaced
        entry = op_list[template_key].copy()
        entry["filter"] = kernel
        entry["template"] = False
        op_list[name] = entry

    # Dynamically create op lists for convolutions with a list of kernel sizes
    KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
    templates_2d = (
        ("conv2d_{}x{}", "conv2d_TEMPLATE"),
        ("depthwise_conv2d_{}x{}", "depthwise_conv2d_TEMPLATE"),
        ("transpose_conv2d_{}x{}", "transpose_conv2d_TEMPLATE"),
    )
    for kernel in KERNELS_2D:
        for name_format, template_key in templates_2d:
            add_from_template(name_format.format(*kernel), template_key, kernel)

    KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
    for kernel in KERNELS_3D:
        add_from_template(
            "conv3d_{}x{}x{}".format(*kernel), "conv3d_TEMPLATE", kernel
        )

    # Delete any templates after having created any dynamic ops
    # This is a two-pass operation because it's bad practice to delete
    # keys from dictionaries while iterating
    template_keys = [
        key for key, entry in op_list.items() if entry.get("template")
    ]
    for key in template_keys:
        del op_list[key]
2495
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Raises:
        Exception: if an op lacks a two-element "operands", a four-element
            "build_fcn", a "types" field or an "op" field.
    """
    # Fields that only need to be present, with the error to report if not
    presence_checks = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields: these two must also unpack to the right length
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        for field, message in presence_checks:
            try:
                entry[field]
            except KeyError:
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07002537
# Tensor operator list
#  'op': Op enum value of the operator
#  'operands': tuple of (placeholder, const) operands
#  'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE
#  'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
#  'types': array of datatypes to be tested
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types; excludes INT4
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer types followed by floating-point types; excludes INT4
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# floating-types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range used for ops that do not specify their own "rank"
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002589
2590 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002591 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002592 "argmax": {
2593 "op": Op.ARGMAX,
2594 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002595 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002596 "build_fcn": (
2597 build_argmax,
2598 TosaTensorGen.tgBasic,
2599 TosaTensorValuesGen.tvgDefault,
2600 TosaArgGen.agAxis,
2601 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002602 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002603 "error_if_validators": (
2604 TosaErrorValidator.evAxisSmallerZero,
2605 TosaErrorValidator.evAxisLargerRank,
2606 TosaErrorValidator.evArgmaxOutputRankMismatch,
2607 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2608 TosaErrorValidator.evWrongRank,
2609 TosaErrorValidator.evWrongInputType,
2610 TosaErrorValidator.evWrongOutputType,
2611 TosaErrorValidator.evWrongInputList,
2612 TosaErrorValidator.evWrongOutputList,
2613 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002615 "avg_pool2d": {
2616 "op": Op.AVG_POOL2D,
2617 "operands": (1, 0),
2618 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002619 "build_fcn": (
2620 build_pool2d,
2621 TosaTensorGen.tgNHWC,
2622 TosaTensorValuesGen.tvgDefault,
2623 TosaArgGen.agPooling,
2624 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002625 "qgen": TosaQuantGen.qgUnary,
2626 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002627 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002628 "error_if_validators": (
2629 TosaErrorValidator.evKernelSmallerOne,
2630 TosaErrorValidator.evStrideSmallerOne,
2631 TosaErrorValidator.evPadSmallerZero,
2632 TosaErrorValidator.evWrongRank,
2633 TosaErrorValidator.evWrongInputType,
2634 TosaErrorValidator.evWrongOutputType,
2635 TosaErrorValidator.evWrongInputList,
2636 TosaErrorValidator.evWrongOutputList,
2637 TosaErrorValidator.evInputZeroPointNotZero,
2638 TosaErrorValidator.evOutputZeroPointNotZero,
2639 TosaErrorValidator.evPadLargerEqualKernel,
2640 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002641 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002642 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002643 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002644 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002645 "conv2d_TEMPLATE": {
2646 "op": Op.CONV2D,
2647 "operands": (1, 2),
2648 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002649 "build_fcn": (
2650 build_conv2d,
2651 TosaTensorGen.tgConv2D,
2652 TosaTensorValuesGen.tvgDefault,
2653 TosaArgGen.agConv,
2654 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002655 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002656 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002657 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2658 "error_if_validators": (
2659 TosaErrorValidator.evWrongInputType,
2660 TosaErrorValidator.evWrongOutputType,
2661 TosaErrorValidator.evWrongInputList,
2662 TosaErrorValidator.evWrongOutputList,
2663 TosaErrorValidator.evInputZeroPointNotZero,
2664 TosaErrorValidator.evWeightZeroPointNotZero,
2665 TosaErrorValidator.evPadSmallerZero,
2666 TosaErrorValidator.evStrideSmallerOne,
2667 TosaErrorValidator.evDilationSmallerOne,
2668 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002669 TosaErrorValidator.evConvOutputShapeMismatch,
2670 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002671 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002672 "template": True,
2673 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002674 # Templated operator. Filled in by createDynamicOpLists
2675 "conv3d_TEMPLATE": {
2676 "op": Op.CONV3D,
2677 "operands": (1, 2),
2678 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002679 "build_fcn": (
2680 build_conv3d,
2681 TosaTensorGen.tgConv3D,
2682 TosaTensorValuesGen.tvgDefault,
2683 TosaArgGen.agConv,
2684 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002685 "qgen": TosaQuantGen.qgConv,
2686 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002687 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2688 "error_if_validators": (
2689 TosaErrorValidator.evWrongInputType,
2690 TosaErrorValidator.evWrongOutputType,
2691 TosaErrorValidator.evWrongInputList,
2692 TosaErrorValidator.evWrongOutputList,
2693 TosaErrorValidator.evInputZeroPointNotZero,
2694 TosaErrorValidator.evWeightZeroPointNotZero,
2695 TosaErrorValidator.evPadSmallerZero,
2696 TosaErrorValidator.evStrideSmallerOne,
2697 TosaErrorValidator.evDilationSmallerOne,
2698 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002699 TosaErrorValidator.evConvOutputShapeMismatch,
2700 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002701 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002702 "template": True,
2703 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002704 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002705 "depthwise_conv2d_TEMPLATE": {
2706 "op": Op.DEPTHWISE_CONV2D,
2707 "operands": (1, 2),
2708 "filter": [1, 1],
2709 "rank": (4, 4),
2710 "build_fcn": (
2711 build_depthwise_conv2d,
2712 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002713 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002714 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002715 ),
2716 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002717 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002718 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2719 "error_if_validators": (
2720 TosaErrorValidator.evWrongInputType,
2721 TosaErrorValidator.evWrongOutputType,
2722 TosaErrorValidator.evWrongInputList,
2723 TosaErrorValidator.evWrongOutputList,
2724 TosaErrorValidator.evInputZeroPointNotZero,
2725 TosaErrorValidator.evWeightZeroPointNotZero,
2726 TosaErrorValidator.evPadSmallerZero,
2727 TosaErrorValidator.evStrideSmallerOne,
2728 TosaErrorValidator.evDilationSmallerOne,
2729 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002730 TosaErrorValidator.evConvOutputShapeMismatch,
2731 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002732 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002733 "template": True,
2734 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002735 "fully_connected": {
2736 "op": Op.FULLY_CONNECTED,
2737 "operands": (1, 2),
2738 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002739 "build_fcn": (
2740 build_fully_connected,
2741 TosaTensorGen.tgFullyConnected,
2742 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002743 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002744 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002745 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002746 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002747 "error_if_validators": (
2748 TosaErrorValidator.evInputZeroPointNotZero,
2749 TosaErrorValidator.evWeightZeroPointNotZero,
2750 TosaErrorValidator.evWrongRank,
2751 TosaErrorValidator.evWrongInputType,
2752 TosaErrorValidator.evWrongOutputType,
2753 TosaErrorValidator.evWrongInputList,
2754 TosaErrorValidator.evWrongOutputList,
2755 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002756 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002757 "matmul": {
2758 "op": Op.MATMUL,
2759 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002760 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002761 "build_fcn": (
2762 build_matmul,
2763 TosaTensorGen.tgMatmul,
2764 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002765 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002767 "qgen": TosaQuantGen.qgMatmul,
2768 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002769 "error_if_validators": (
2770 TosaErrorValidator.evInputZeroPointNotZero,
2771 TosaErrorValidator.evWrongRank,
2772 TosaErrorValidator.evWrongInputType,
2773 TosaErrorValidator.evWrongOutputType,
2774 TosaErrorValidator.evWrongInputList,
2775 TosaErrorValidator.evWrongOutputList,
2776 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002777 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002778 "max_pool2d": {
2779 "op": Op.MAX_POOL2D,
2780 "operands": (1, 0),
2781 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002782 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002783 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002784 TosaTensorGen.tgNHWC,
2785 TosaTensorValuesGen.tvgDefault,
2786 TosaArgGen.agPooling,
2787 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002788 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002789 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002790 "error_if_validators": (
2791 TosaErrorValidator.evKernelSmallerOne,
2792 TosaErrorValidator.evStrideSmallerOne,
2793 TosaErrorValidator.evPadSmallerZero,
2794 TosaErrorValidator.evWrongRank,
2795 TosaErrorValidator.evWrongInputType,
2796 TosaErrorValidator.evWrongOutputType,
2797 TosaErrorValidator.evWrongInputList,
2798 TosaErrorValidator.evWrongOutputList,
2799 TosaErrorValidator.evPadLargerEqualKernel,
2800 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002801 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002803 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002804 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002805 "transpose_conv2d_TEMPLATE": {
2806 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002807 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002808 "rank": (4, 4),
2809 "build_fcn": (
2810 build_transpose_conv2d,
2811 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002812 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002813 TosaArgGen.agTransposeConv2D,
2814 ),
2815 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002816 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002817 "invalid_test_validators": (
2818 TosaInvalidValidator.ivHeightWidthInvalid,
2819 TosaInvalidValidator.ivNonPositiveOutputShape,
2820 ),
2821 "error_if_validators": (
2822 TosaErrorValidator.evWrongInputType,
2823 TosaErrorValidator.evWrongOutputType,
2824 TosaErrorValidator.evWrongInputList,
2825 TosaErrorValidator.evWrongOutputList,
2826 TosaErrorValidator.evInputZeroPointNotZero,
2827 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002828 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002829 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002830 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002831 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002832 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002833 "template": True,
2834 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002835 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002836 "clamp": {
2837 "op": Op.CLAMP,
2838 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002839 "build_fcn": (
2840 build_clamp,
2841 TosaTensorGen.tgBasic,
2842 TosaTensorValuesGen.tvgDefault,
2843 None,
2844 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002845 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002846 "error_if_validators": (
2847 TosaErrorValidator.evMaxSmallerMin,
2848 TosaErrorValidator.evWrongInputType,
2849 TosaErrorValidator.evWrongOutputType,
2850 TosaErrorValidator.evWrongInputList,
2851 TosaErrorValidator.evWrongOutputList,
2852 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002853 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002854 "sigmoid": {
2855 "op": Op.SIGMOID,
2856 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002857 "build_fcn": (
2858 build_sigmoid,
2859 TosaTensorGen.tgBasic,
2860 TosaTensorValuesGen.tvgDefault,
2861 None,
2862 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002863 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002864 "error_if_validators": (
2865 TosaErrorValidator.evWrongInputType,
2866 TosaErrorValidator.evWrongOutputType,
2867 TosaErrorValidator.evWrongInputList,
2868 TosaErrorValidator.evWrongOutputList,
2869 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002870 },
2871 "tanh": {
2872 "op": Op.TANH,
2873 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002874 "build_fcn": (
2875 build_tanh,
2876 TosaTensorGen.tgBasic,
2877 TosaTensorValuesGen.tvgDefault,
2878 None,
2879 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002880 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002881 "error_if_validators": (
2882 TosaErrorValidator.evWrongInputType,
2883 TosaErrorValidator.evWrongOutputType,
2884 TosaErrorValidator.evWrongInputList,
2885 TosaErrorValidator.evWrongOutputList,
2886 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002887 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002888 # Elementwise Binary Operators
2889 "add": {
2890 "op": Op.ADD,
2891 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002892 "build_fcn": (
2893 build_binary_broadcast,
2894 TosaTensorGen.tgBroadcastFuzz,
2895 TosaTensorValuesGen.tvgAddSub,
2896 None,
2897 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002898 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002899 "error_if_validators": (
2900 TosaErrorValidator.evRankMismatch,
2901 TosaErrorValidator.evWrongInputType,
2902 TosaErrorValidator.evWrongOutputType,
2903 TosaErrorValidator.evWrongInputList,
2904 TosaErrorValidator.evWrongOutputList,
2905 TosaErrorValidator.evDimensionMismatch,
2906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002907 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002908 "arithmetic_right_shift": {
2909 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2910 "operands": (2, 0),
2911 "build_fcn": (
2912 build_arithmetic_right_shift,
2913 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002914 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002915 TosaArgGen.agArithmeticRightShift,
2916 ),
2917 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002918 "error_if_validators": (
2919 TosaErrorValidator.evRankMismatch,
2920 TosaErrorValidator.evWrongInputType,
2921 TosaErrorValidator.evWrongOutputType,
2922 TosaErrorValidator.evWrongInputList,
2923 TosaErrorValidator.evWrongOutputList,
2924 TosaErrorValidator.evDimensionMismatch,
2925 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002926 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002927 "bitwise_and": {
2928 "op": Op.BITWISE_AND,
2929 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002930 "build_fcn": (
2931 build_binary_broadcast,
2932 TosaTensorGen.tgBroadcastFuzz,
2933 TosaTensorValuesGen.tvgDefault,
2934 None,
2935 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002936 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002937 "error_if_validators": (
2938 TosaErrorValidator.evRankMismatch,
2939 TosaErrorValidator.evWrongInputType,
2940 TosaErrorValidator.evWrongOutputType,
2941 TosaErrorValidator.evWrongInputList,
2942 TosaErrorValidator.evWrongOutputList,
2943 TosaErrorValidator.evDimensionMismatch,
2944 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002946 "bitwise_or": {
2947 "op": Op.BITWISE_OR,
2948 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002949 "build_fcn": (
2950 build_binary_broadcast,
2951 TosaTensorGen.tgBroadcastFuzz,
2952 TosaTensorValuesGen.tvgDefault,
2953 None,
2954 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002955 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002956 "error_if_validators": (
2957 TosaErrorValidator.evRankMismatch,
2958 TosaErrorValidator.evWrongInputType,
2959 TosaErrorValidator.evWrongOutputType,
2960 TosaErrorValidator.evWrongInputList,
2961 TosaErrorValidator.evWrongOutputList,
2962 TosaErrorValidator.evDimensionMismatch,
2963 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002964 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002965 "bitwise_xor": {
2966 "op": Op.BITWISE_XOR,
2967 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002968 "build_fcn": (
2969 build_binary_broadcast,
2970 TosaTensorGen.tgBroadcastFuzz,
2971 TosaTensorValuesGen.tvgDefault,
2972 None,
2973 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002974 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002975 "error_if_validators": (
2976 TosaErrorValidator.evRankMismatch,
2977 TosaErrorValidator.evWrongInputType,
2978 TosaErrorValidator.evWrongOutputType,
2979 TosaErrorValidator.evWrongInputList,
2980 TosaErrorValidator.evWrongOutputList,
2981 TosaErrorValidator.evDimensionMismatch,
2982 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002983 },
Matthew Haddon459443c2021-08-23 16:43:13 +01002984 "intdiv": {
2985 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002986 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002987 "build_fcn": (
2988 build_binary_broadcast,
2989 TosaTensorGen.tgBroadcastFuzz,
2990 TosaTensorValuesGen.tvgIntDiv,
2991 None,
2992 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002993 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002994 "error_if_validators": (
2995 TosaErrorValidator.evRankMismatch,
2996 TosaErrorValidator.evWrongInputType,
2997 TosaErrorValidator.evWrongOutputType,
2998 TosaErrorValidator.evWrongInputList,
2999 TosaErrorValidator.evWrongOutputList,
3000 TosaErrorValidator.evDimensionMismatch,
3001 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003002 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003003 "logical_and": {
3004 "op": Op.LOGICAL_AND,
3005 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003006 "build_fcn": (
3007 build_binary_broadcast,
3008 TosaTensorGen.tgBroadcastFuzz,
3009 TosaTensorValuesGen.tvgDefault,
3010 None,
3011 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003012 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003013 "error_if_validators": (
3014 TosaErrorValidator.evRankMismatch,
3015 TosaErrorValidator.evWrongInputType,
3016 TosaErrorValidator.evWrongOutputType,
3017 TosaErrorValidator.evWrongInputList,
3018 TosaErrorValidator.evWrongOutputList,
3019 TosaErrorValidator.evDimensionMismatch,
3020 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003021 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003022 "logical_left_shift": {
3023 "op": Op.LOGICAL_LEFT_SHIFT,
3024 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003025 "build_fcn": (
3026 build_binary_broadcast,
3027 TosaTensorGen.tgBroadcastFuzz,
3028 TosaTensorValuesGen.tvgLogicalShift,
3029 None,
3030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003031 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003032 "error_if_validators": (
3033 TosaErrorValidator.evRankMismatch,
3034 TosaErrorValidator.evWrongInputType,
3035 TosaErrorValidator.evWrongOutputType,
3036 TosaErrorValidator.evWrongInputList,
3037 TosaErrorValidator.evWrongOutputList,
3038 TosaErrorValidator.evDimensionMismatch,
3039 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003040 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003041 "logical_right_shift": {
3042 "op": Op.LOGICAL_RIGHT_SHIFT,
3043 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 "build_fcn": (
3045 build_binary_broadcast,
3046 TosaTensorGen.tgBroadcastFuzz,
3047 TosaTensorValuesGen.tvgLogicalShift,
3048 None,
3049 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003050 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003051 "error_if_validators": (
3052 TosaErrorValidator.evRankMismatch,
3053 TosaErrorValidator.evWrongInputType,
3054 TosaErrorValidator.evWrongOutputType,
3055 TosaErrorValidator.evWrongInputList,
3056 TosaErrorValidator.evWrongOutputList,
3057 TosaErrorValidator.evDimensionMismatch,
3058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003060 "logical_or": {
3061 "op": Op.LOGICAL_OR,
3062 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003063 "build_fcn": (
3064 build_binary_broadcast,
3065 TosaTensorGen.tgBroadcastFuzz,
3066 TosaTensorValuesGen.tvgDefault,
3067 None,
3068 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003069 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003070 "error_if_validators": (
3071 TosaErrorValidator.evRankMismatch,
3072 TosaErrorValidator.evWrongInputType,
3073 TosaErrorValidator.evWrongOutputType,
3074 TosaErrorValidator.evWrongInputList,
3075 TosaErrorValidator.evWrongOutputList,
3076 TosaErrorValidator.evDimensionMismatch,
3077 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 "logical_xor": {
3080 "op": Op.LOGICAL_XOR,
3081 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003082 "build_fcn": (
3083 build_binary_broadcast,
3084 TosaTensorGen.tgBroadcastFuzz,
3085 TosaTensorValuesGen.tvgDefault,
3086 None,
3087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003088 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003089 "error_if_validators": (
3090 TosaErrorValidator.evRankMismatch,
3091 TosaErrorValidator.evWrongInputType,
3092 TosaErrorValidator.evWrongOutputType,
3093 TosaErrorValidator.evWrongInputList,
3094 TosaErrorValidator.evWrongOutputList,
3095 TosaErrorValidator.evDimensionMismatch,
3096 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003097 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003098 "maximum": {
3099 "op": Op.MAXIMUM,
3100 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003101 "build_fcn": (
3102 build_binary_broadcast,
3103 TosaTensorGen.tgBroadcastFuzz,
3104 TosaTensorValuesGen.tvgDefault,
3105 None,
3106 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003107 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003108 "error_if_validators": (
3109 TosaErrorValidator.evRankMismatch,
3110 TosaErrorValidator.evWrongInputType,
3111 TosaErrorValidator.evWrongOutputType,
3112 TosaErrorValidator.evWrongInputList,
3113 TosaErrorValidator.evWrongOutputList,
3114 TosaErrorValidator.evDimensionMismatch,
3115 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003117 "minimum": {
3118 "op": Op.MINIMUM,
3119 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120 "build_fcn": (
3121 build_binary_broadcast,
3122 TosaTensorGen.tgBroadcastFuzz,
3123 TosaTensorValuesGen.tvgDefault,
3124 None,
3125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003126 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003127 "error_if_validators": (
3128 TosaErrorValidator.evRankMismatch,
3129 TosaErrorValidator.evWrongInputType,
3130 TosaErrorValidator.evWrongOutputType,
3131 TosaErrorValidator.evWrongInputList,
3132 TosaErrorValidator.evWrongOutputList,
3133 TosaErrorValidator.evDimensionMismatch,
3134 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003135 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003136 "mul": {
3137 "op": Op.MUL,
3138 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003139 "build_fcn": (
3140 build_mul,
3141 TosaTensorGen.tgBroadcastFuzz,
3142 TosaTensorValuesGen.tvgMul,
3143 TosaArgGen.agMul,
3144 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003145 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003146 "error_if_validators": (
3147 TosaErrorValidator.evWrongInputType,
3148 TosaErrorValidator.evWrongOutputType,
3149 TosaErrorValidator.evWrongInputList,
3150 TosaErrorValidator.evWrongOutputList,
3151 TosaErrorValidator.evRankMismatch,
3152 TosaErrorValidator.evDimensionMismatch,
3153 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003154 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003155 "pow": {
3156 "op": Op.POW,
3157 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003158 "build_fcn": (
3159 build_binary_broadcast,
3160 TosaTensorGen.tgBroadcastFuzz,
3161 TosaTensorValuesGen.tvgDefault,
3162 None,
3163 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003164 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003165 "error_if_validators": (
3166 TosaErrorValidator.evRankMismatch,
3167 TosaErrorValidator.evWrongInputType,
3168 TosaErrorValidator.evWrongOutputType,
3169 TosaErrorValidator.evWrongInputList,
3170 TosaErrorValidator.evWrongOutputList,
3171 TosaErrorValidator.evDimensionMismatch,
3172 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003173 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003174 "sub": {
3175 "op": Op.SUB,
3176 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003177 "build_fcn": (
3178 build_binary_broadcast,
3179 TosaTensorGen.tgBroadcastFuzz,
3180 TosaTensorValuesGen.tvgAddSub,
3181 None,
3182 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003183 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003184 "error_if_validators": (
3185 TosaErrorValidator.evRankMismatch,
3186 TosaErrorValidator.evWrongInputType,
3187 TosaErrorValidator.evWrongOutputType,
3188 TosaErrorValidator.evWrongInputList,
3189 TosaErrorValidator.evWrongOutputList,
3190 TosaErrorValidator.evDimensionMismatch,
3191 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003192 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003193 "table": {
3194 "op": Op.TABLE,
3195 # Use the automatic generation functions to create the input array
3196 # but create the table tensor in the build function, as it may be
3197 # a different type from the input
3198 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003199 "build_fcn": (
3200 build_table,
3201 TosaTensorGen.tgBasic,
3202 TosaTensorValuesGen.tvgDefault,
3203 TosaArgGen.agTable,
3204 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003205 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003206 "error_if_validators": (
3207 TosaErrorValidator.evWrongInputType,
3208 TosaErrorValidator.evWrongOutputType,
3209 TosaErrorValidator.evWrongInputList,
3210 TosaErrorValidator.evWrongOutputList,
3211 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003212 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003213 # Elementwise Unary operators
3214 "abs": {
3215 "op": Op.ABS,
3216 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003217 "build_fcn": (
3218 build_unary,
3219 TosaTensorGen.tgBasic,
3220 TosaTensorValuesGen.tvgDefault,
3221 None,
3222 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003223 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003224 "error_if_validators": (
3225 TosaErrorValidator.evWrongInputType,
3226 TosaErrorValidator.evWrongOutputType,
3227 TosaErrorValidator.evWrongInputList,
3228 TosaErrorValidator.evWrongOutputList,
3229 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003231 "bitwise_not": {
3232 "op": Op.BITWISE_NOT,
3233 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003234 "build_fcn": (
3235 build_unary,
3236 TosaTensorGen.tgBasic,
3237 TosaTensorValuesGen.tvgDefault,
3238 None,
3239 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003240 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003241 "error_if_validators": (
3242 TosaErrorValidator.evWrongInputType,
3243 TosaErrorValidator.evWrongOutputType,
3244 TosaErrorValidator.evWrongInputList,
3245 TosaErrorValidator.evWrongOutputList,
3246 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 "ceil": {
3249 "op": Op.CEIL,
3250 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 "build_fcn": (
3252 build_unary,
3253 TosaTensorGen.tgBasic,
3254 TosaTensorValuesGen.tvgDefault,
3255 None,
3256 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003257 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003258 "error_if_validators": (
3259 TosaErrorValidator.evWrongInputType,
3260 TosaErrorValidator.evWrongOutputType,
3261 TosaErrorValidator.evWrongInputList,
3262 TosaErrorValidator.evWrongOutputList,
3263 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003264 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003265 "clz": {
3266 "op": Op.CLZ,
3267 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003268 "build_fcn": (
3269 build_unary,
3270 TosaTensorGen.tgBasic,
3271 TosaTensorValuesGen.tvgDefault,
3272 None,
3273 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003274 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003275 "error_if_validators": (
3276 TosaErrorValidator.evWrongInputType,
3277 TosaErrorValidator.evWrongOutputType,
3278 TosaErrorValidator.evWrongInputList,
3279 TosaErrorValidator.evWrongOutputList,
3280 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003281 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003282 "exp": {
3283 "op": Op.EXP,
3284 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003285 "build_fcn": (
3286 build_unary,
3287 TosaTensorGen.tgBasic,
3288 TosaTensorValuesGen.tvgDefault,
3289 None,
3290 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003291 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003292 "error_if_validators": (
3293 TosaErrorValidator.evWrongInputType,
3294 TosaErrorValidator.evWrongOutputType,
3295 TosaErrorValidator.evWrongInputList,
3296 TosaErrorValidator.evWrongOutputList,
3297 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003298 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 "floor": {
3300 "op": Op.FLOOR,
3301 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003302 "build_fcn": (
3303 build_unary,
3304 TosaTensorGen.tgBasic,
3305 TosaTensorValuesGen.tvgDefault,
3306 None,
3307 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003308 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003309 "error_if_validators": (
3310 TosaErrorValidator.evWrongInputType,
3311 TosaErrorValidator.evWrongOutputType,
3312 TosaErrorValidator.evWrongInputList,
3313 TosaErrorValidator.evWrongOutputList,
3314 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003315 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 "log": {
3317 "op": Op.LOG,
3318 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003319 "build_fcn": (
3320 build_unary,
3321 TosaTensorGen.tgBasic,
3322 TosaTensorValuesGen.tvgDefault,
3323 None,
3324 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003325 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003326 "error_if_validators": (
3327 TosaErrorValidator.evWrongInputType,
3328 TosaErrorValidator.evWrongOutputType,
3329 TosaErrorValidator.evWrongInputList,
3330 TosaErrorValidator.evWrongOutputList,
3331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003332 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 "logical_not": {
3334 "op": Op.LOGICAL_NOT,
3335 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003336 "build_fcn": (
3337 build_unary,
3338 TosaTensorGen.tgBasic,
3339 TosaTensorValuesGen.tvgDefault,
3340 None,
3341 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003342 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003343 "error_if_validators": (
3344 TosaErrorValidator.evWrongInputType,
3345 TosaErrorValidator.evWrongOutputType,
3346 TosaErrorValidator.evWrongInputList,
3347 TosaErrorValidator.evWrongOutputList,
3348 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003349 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 "negate": {
3351 "op": Op.NEGATE,
3352 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003353 "build_fcn": (
3354 build_unary,
3355 TosaTensorGen.tgBasic,
3356 TosaTensorValuesGen.tvgNegate,
3357 None,
3358 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003359 "qgen": TosaQuantGen.qgUnary,
3360 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003361 "error_if_validators": (
3362 TosaErrorValidator.evInputZeroPointNotZero,
3363 TosaErrorValidator.evOutputZeroPointNotZero,
3364 TosaErrorValidator.evWrongInputType,
3365 TosaErrorValidator.evWrongOutputType,
3366 TosaErrorValidator.evWrongInputList,
3367 TosaErrorValidator.evWrongOutputList,
3368 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 "reciprocal": {
3371 "op": Op.RECIPROCAL,
3372 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003373 "build_fcn": (
3374 build_unary,
3375 TosaTensorGen.tgBasic,
3376 TosaTensorValuesGen.tvgDefault,
3377 None,
3378 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003379 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003380 "error_if_validators": (
3381 TosaErrorValidator.evWrongInputType,
3382 TosaErrorValidator.evWrongOutputType,
3383 TosaErrorValidator.evWrongInputList,
3384 TosaErrorValidator.evWrongOutputList,
3385 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003386 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003387 "rsqrt": {
3388 "op": Op.RSQRT,
3389 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003390 "build_fcn": (
3391 build_unary,
3392 TosaTensorGen.tgBasic,
3393 TosaTensorValuesGen.tvgDefault,
3394 None,
3395 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003397 "error_if_validators": (
3398 TosaErrorValidator.evWrongInputType,
3399 TosaErrorValidator.evWrongOutputType,
3400 TosaErrorValidator.evWrongInputList,
3401 TosaErrorValidator.evWrongOutputList,
3402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003403 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003404 # Elementwise Ternary operators
3405 "select": {
3406 "op": Op.SELECT,
3407 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003408 "build_fcn": (
3409 build_select,
3410 TosaTensorGen.tgBroadcastFuzz,
3411 TosaTensorValuesGen.tvgSelect,
3412 None,
3413 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003414 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003415 "error_if_validators": (
3416 TosaErrorValidator.evRankMismatch,
3417 TosaErrorValidator.evWrongInputType,
3418 TosaErrorValidator.evWrongOutputType,
3419 TosaErrorValidator.evWrongInputList,
3420 TosaErrorValidator.evWrongOutputList,
3421 TosaErrorValidator.evDimensionMismatch,
3422 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003423 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 # Comparison operators
3425 "equal": {
3426 "op": Op.EQUAL,
3427 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003428 "build_fcn": (
3429 build_comparison,
3430 TosaTensorGen.tgBroadcastFuzz,
3431 TosaTensorValuesGen.tvgEqual,
3432 None,
3433 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003435 "error_if_validators": (
3436 TosaErrorValidator.evRankMismatch,
3437 TosaErrorValidator.evWrongInputType,
3438 TosaErrorValidator.evWrongOutputType,
3439 TosaErrorValidator.evWrongInputList,
3440 TosaErrorValidator.evWrongOutputList,
3441 TosaErrorValidator.evDimensionMismatch,
3442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003444 "greater_equal": {
3445 "op": Op.GREATER_EQUAL,
3446 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003447 "build_fcn": (
3448 build_comparison,
3449 TosaTensorGen.tgBroadcastFuzz,
3450 TosaTensorValuesGen.tvgDefault,
3451 None,
3452 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003453 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003454 "error_if_validators": (
3455 TosaErrorValidator.evRankMismatch,
3456 TosaErrorValidator.evWrongInputType,
3457 TosaErrorValidator.evWrongOutputType,
3458 TosaErrorValidator.evWrongInputList,
3459 TosaErrorValidator.evWrongOutputList,
3460 TosaErrorValidator.evDimensionMismatch,
3461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 "greater": {
3464 "op": Op.GREATER,
3465 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003466 "build_fcn": (
3467 build_comparison,
3468 TosaTensorGen.tgBroadcastFuzz,
3469 TosaTensorValuesGen.tvgDefault,
3470 None,
3471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003473 "error_if_validators": (
3474 TosaErrorValidator.evRankMismatch,
3475 TosaErrorValidator.evWrongInputType,
3476 TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongInputList,
3478 TosaErrorValidator.evWrongOutputList,
3479 TosaErrorValidator.evDimensionMismatch,
3480 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003481 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003482 # Reduction operators
3483 "reduce_all": {
3484 "op": Op.REDUCE_ALL,
3485 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003486 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 "build_fcn": (
3488 build_reduce,
3489 TosaTensorGen.tgBasic,
3490 TosaTensorValuesGen.tvgDefault,
3491 TosaArgGen.agAxis,
3492 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003494 "error_if_validators": (
3495 TosaErrorValidator.evAxisLargerRank,
3496 TosaErrorValidator.evAxisSmallerZero,
3497 TosaErrorValidator.evShapeOfAxisNotOne,
3498 TosaErrorValidator.evWrongInputType,
3499 TosaErrorValidator.evWrongOutputType,
3500 TosaErrorValidator.evWrongRank,
3501 TosaErrorValidator.evWrongInputList,
3502 TosaErrorValidator.evWrongOutputList,
3503 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003505 "reduce_any": {
3506 "op": Op.REDUCE_ANY,
3507 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003508 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003509 "build_fcn": (
3510 build_reduce,
3511 TosaTensorGen.tgBasic,
3512 TosaTensorValuesGen.tvgDefault,
3513 TosaArgGen.agAxis,
3514 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003516 "error_if_validators": (
3517 TosaErrorValidator.evAxisLargerRank,
3518 TosaErrorValidator.evAxisSmallerZero,
3519 TosaErrorValidator.evShapeOfAxisNotOne,
3520 TosaErrorValidator.evWrongInputType,
3521 TosaErrorValidator.evWrongOutputType,
3522 TosaErrorValidator.evWrongRank,
3523 TosaErrorValidator.evWrongInputList,
3524 TosaErrorValidator.evWrongOutputList,
3525 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 "reduce_max": {
3528 "op": Op.REDUCE_MAX,
3529 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003530 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003531 "build_fcn": (
3532 build_reduce,
3533 TosaTensorGen.tgBasic,
3534 TosaTensorValuesGen.tvgDefault,
3535 TosaArgGen.agAxis,
3536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003538 "error_if_validators": (
3539 TosaErrorValidator.evAxisLargerRank,
3540 TosaErrorValidator.evAxisSmallerZero,
3541 TosaErrorValidator.evShapeOfAxisNotOne,
3542 TosaErrorValidator.evWrongInputType,
3543 TosaErrorValidator.evWrongOutputType,
3544 TosaErrorValidator.evWrongRank,
3545 TosaErrorValidator.evWrongInputList,
3546 TosaErrorValidator.evWrongOutputList,
3547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003550 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003552 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 "build_fcn": (
3554 build_reduce,
3555 TosaTensorGen.tgBasic,
3556 TosaTensorValuesGen.tvgDefault,
3557 TosaArgGen.agAxis,
3558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 "error_if_validators": (
3561 TosaErrorValidator.evAxisLargerRank,
3562 TosaErrorValidator.evAxisSmallerZero,
3563 TosaErrorValidator.evShapeOfAxisNotOne,
3564 TosaErrorValidator.evWrongInputType,
3565 TosaErrorValidator.evWrongOutputType,
3566 TosaErrorValidator.evWrongRank,
3567 TosaErrorValidator.evWrongInputList,
3568 TosaErrorValidator.evWrongOutputList,
3569 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 "reduce_product": {
3572 "op": Op.REDUCE_PRODUCT,
3573 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003574 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 "build_fcn": (
3576 build_reduce,
3577 TosaTensorGen.tgBasic,
3578 TosaTensorValuesGen.tvgDefault,
3579 TosaArgGen.agAxis,
3580 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 "error_if_validators": (
3583 TosaErrorValidator.evAxisLargerRank,
3584 TosaErrorValidator.evAxisSmallerZero,
3585 TosaErrorValidator.evShapeOfAxisNotOne,
3586 TosaErrorValidator.evWrongInputType,
3587 TosaErrorValidator.evWrongOutputType,
3588 TosaErrorValidator.evWrongRank,
3589 TosaErrorValidator.evWrongInputList,
3590 TosaErrorValidator.evWrongOutputList,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 "reduce_sum": {
3594 "op": Op.REDUCE_SUM,
3595 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003596 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003597 "build_fcn": (
3598 build_reduce,
3599 TosaTensorGen.tgBasic,
3600 TosaTensorValuesGen.tvgReduceSum,
3601 TosaArgGen.agAxis,
3602 ),
James Ward24dbc422022-10-19 12:20:31 +01003603 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003604 "error_if_validators": (
3605 TosaErrorValidator.evAxisLargerRank,
3606 TosaErrorValidator.evAxisSmallerZero,
3607 TosaErrorValidator.evShapeOfAxisNotOne,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongRank,
3611 TosaErrorValidator.evWrongInputList,
3612 TosaErrorValidator.evWrongOutputList,
3613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003615 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003616 "concat": {
3617 "op": Op.CONCAT,
3618 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 "build_fcn": (
3620 build_concat,
3621 TosaTensorGen.tgConcat,
3622 TosaTensorValuesGen.tvgConcat,
3623 TosaArgGen.agAxis,
3624 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003625 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003626 "error_if_validators": (
3627 TosaErrorValidator.evAxisLargerRank,
3628 TosaErrorValidator.evAxisSmallerZero,
3629 TosaErrorValidator.evConcatInputRankMismatch,
3630 TosaErrorValidator.evConcatShapeSumMismatch,
3631 TosaErrorValidator.evConcatInputDimMismatch,
3632 TosaErrorValidator.evWrongInputType,
3633 TosaErrorValidator.evWrongOutputType,
3634 TosaErrorValidator.evWrongOutputList,
3635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003636 },
3637 "pad": {
3638 "op": Op.PAD,
3639 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003640 "rank": (1, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 "build_fcn": (
3642 build_pad,
3643 TosaTensorGen.tgBasic,
3644 TosaTensorValuesGen.tvgDefault,
3645 TosaArgGen.agPad,
3646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003648 "error_if_validators": (
3649 TosaErrorValidator.evWrongInputType,
3650 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003651 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003652 TosaErrorValidator.evWrongOutputType,
3653 TosaErrorValidator.evWrongInputList,
3654 TosaErrorValidator.evWrongOutputList,
3655 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003656 },
3657 "reshape": {
3658 "op": Op.RESHAPE,
3659 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 "build_fcn": (
3661 build_reshape,
3662 TosaTensorGen.tgBasic,
3663 TosaTensorValuesGen.tvgDefault,
3664 TosaArgGen.agReshape,
3665 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003666 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003667 "error_if_validators": (
3668 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3669 TosaErrorValidator.evWrongInputType,
3670 TosaErrorValidator.evWrongOutputType,
3671 TosaErrorValidator.evWrongInputList,
3672 TosaErrorValidator.evWrongOutputList,
3673 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003674 },
3675 "reverse": {
3676 "op": Op.REVERSE,
3677 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003678 "build_fcn": (
3679 build_reverse,
3680 TosaTensorGen.tgBasic,
3681 TosaTensorValuesGen.tvgDefault,
3682 TosaArgGen.agAxis,
3683 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003684 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003685 "error_if_validators": (
3686 TosaErrorValidator.evAxisSmallerZero,
3687 TosaErrorValidator.evAxisLargerRank,
3688 TosaErrorValidator.evWrongInputType,
3689 TosaErrorValidator.evWrongOutputType,
3690 TosaErrorValidator.evWrongInputList,
3691 TosaErrorValidator.evWrongOutputList,
3692 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003693 },
3694 "slice": {
3695 "op": Op.SLICE,
3696 "operands": (1, 0),
Matthew Haddone807aae2021-10-11 18:12:58 +01003697 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003698 "build_fcn": (
3699 build_slice,
3700 TosaTensorGen.tgBasic,
3701 TosaTensorValuesGen.tvgDefault,
3702 TosaArgGen.agSlice,
3703 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003704 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003705 "error_if_validators": (
3706 TosaErrorValidator.evStartSmallerZero,
3707 TosaErrorValidator.evSizeSmallerEqualZero,
3708 TosaErrorValidator.evStartSizeOutsideBounds,
3709 TosaErrorValidator.evSizeOutputShapeMismatch,
3710 TosaErrorValidator.evInputSizeStartLengthMismatch,
3711 TosaErrorValidator.evWrongRank,
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
3716 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003717 },
3718 "tile": {
3719 "op": Op.TILE,
3720 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 "build_fcn": (
3722 build_tile,
3723 TosaTensorGen.tgBasic,
3724 TosaTensorValuesGen.tvgDefault,
3725 TosaArgGen.agTile,
3726 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003727 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003728 "error_if_validators": (
3729 TosaErrorValidator.evWrongInputType,
3730 TosaErrorValidator.evWrongOutputType,
3731 TosaErrorValidator.evWrongInputList,
3732 TosaErrorValidator.evWrongOutputList,
3733 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003734 },
3735 "transpose": {
3736 "op": Op.TRANSPOSE,
3737 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01003738 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003739 "build_fcn": (
3740 build_transpose,
3741 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003742 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003743 TosaArgGen.agTranspose,
3744 ),
3745 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003746 "error_if_validators": (
3747 TosaErrorValidator.evIndexOutsideBounds,
3748 TosaErrorValidator.evIndexUsedTwice,
3749 TosaErrorValidator.evWrongInputType,
3750 TosaErrorValidator.evWrongOutputType,
3751 TosaErrorValidator.evWrongInputList,
3752 TosaErrorValidator.evWrongOutputList,
3753 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003754 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 # Data nodes
3756 "const": {
3757 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003758 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003759 "build_fcn": (
3760 build_const,
3761 TosaTensorGen.tgBasic,
3762 TosaTensorValuesGen.tvgDefault,
3763 None,
3764 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 "types": TYPE_FIB,
3766 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003767 "identity": {
3768 "op": Op.IDENTITY,
3769 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003770 "build_fcn": (
3771 build_unary,
3772 TosaTensorGen.tgBasic,
3773 TosaTensorValuesGen.tvgDefault,
3774 None,
3775 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003776 "types": TYPE_FIB,
3777 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003778 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003779 "gather": {
3780 "op": Op.GATHER,
3781 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3782 "operands": (1, 0),
3783 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_gather,
3786 TosaTensorGen.tgBasic,
3787 TosaTensorValuesGen.tvgDefault,
3788 None,
3789 ),
James Ward24dbc422022-10-19 12:20:31 +01003790 "types": (
3791 DType.INT8,
3792 DType.INT16,
3793 DType.INT32,
3794 DType.FP16,
3795 DType.BF16,
3796 DType.FP32,
3797 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003798 "error_if_validators": (
3799 TosaErrorValidator.evWrongInputType,
3800 TosaErrorValidator.evWrongOutputType,
3801 TosaErrorValidator.evWrongInputList,
3802 TosaErrorValidator.evWrongOutputList,
3803 TosaErrorValidator.evWrongRank,
3804 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003805 },
3806 "scatter": {
3807 "op": Op.SCATTER,
3808 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003809 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003810 "operands": (2, 0),
3811 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 "build_fcn": (
3813 build_scatter,
3814 TosaTensorGen.tgScatter,
3815 TosaTensorValuesGen.tvgDefault,
3816 None,
3817 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003818 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 "error_if_validators": (
3820 TosaErrorValidator.evWrongInputType,
3821 TosaErrorValidator.evWrongOutputType,
3822 TosaErrorValidator.evWrongInputList,
3823 TosaErrorValidator.evWrongOutputList,
3824 TosaErrorValidator.evWrongRank,
3825 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003826 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003827 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003828 "resize": {
3829 "op": Op.RESIZE,
3830 "operands": (1, 0),
3831 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 "build_fcn": (
3833 build_resize,
3834 TosaTensorGen.tgNHWC,
3835 TosaTensorValuesGen.tvgDefault,
3836 TosaArgGen.agResize,
3837 ),
James Ward24dbc422022-10-19 12:20:31 +01003838 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 "invalid_test_validators": (
3840 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003841 ),
3842 "error_if_validators": (
3843 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003844 TosaErrorValidator.evScaleSmallerEqualZero,
3845 TosaErrorValidator.evScaleNLargerMax,
3846 TosaErrorValidator.evScaleDLargerMax,
3847 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003849 TosaErrorValidator.evBorderSmallerMin,
3850 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 TosaErrorValidator.evWrongInputType,
3852 TosaErrorValidator.evWrongOutputType,
3853 TosaErrorValidator.evWrongRank,
3854 TosaErrorValidator.evWrongInputList,
3855 TosaErrorValidator.evWrongOutputList,
3856 TosaErrorValidator.evBatchMismatch,
3857 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003858 TosaErrorValidator.evResizeOutputShapeMismatch,
3859 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003860 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003861 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003862 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003863 "cast": {
3864 "op": Op.CAST,
3865 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003866 "build_fcn": (
3867 build_cast,
3868 TosaTensorGen.tgBasic,
3869 TosaTensorValuesGen.tvgDefault,
3870 TosaArgGen.agCast,
3871 ),
James Ward8b390432022-08-12 20:48:56 +01003872 "types": (
3873 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003874 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003875 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003876 DType.INT8,
3877 DType.INT16,
3878 DType.INT32,
3879 DType.BOOL,
3880 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003881 "error_if_validators": (
3882 TosaErrorValidator.evWrongInputType,
3883 TosaErrorValidator.evWrongOutputType,
3884 TosaErrorValidator.evWrongInputList,
3885 TosaErrorValidator.evWrongOutputList,
3886 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003887 },
3888 "rescale": {
3889 "op": Op.RESCALE,
3890 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003891 "build_fcn": (
3892 build_rescale,
3893 TosaTensorGen.tgBasic,
3894 TosaTensorValuesGen.tvgDefault,
3895 TosaArgGen.agRescale,
3896 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003897 "types": [
3898 DType.UINT8,
3899 DType.INT8,
3900 DType.INT16,
3901 DType.INT32,
3902 DType.INT48,
3903 DType.UINT16,
3904 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003905 "error_if_validators": (
3906 TosaErrorValidator.evInputZeroPointNotZero,
3907 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003908 TosaErrorValidator.evU16InputZeroPointNotValid,
3909 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003910 TosaErrorValidator.evScaleTrue,
3911 TosaErrorValidator.evScaleNotTrue,
3912 TosaErrorValidator.evWrongInputType,
3913 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003914 TosaErrorValidator.evWrongInputList,
3915 TosaErrorValidator.evWrongOutputList,
3916 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003917 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003918 # Custom
3919 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003921 # Two variants of cond_if, one that generates one of two constant tensors (no
3922 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3923 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003924 "cond_if_const": {
3925 "op": Op.COND_IF,
3926 "operands": (0, 2),
3927 "build_fcn": (
3928 build_cond_if_const,
3929 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003930 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003931 TosaArgGen.agCondIf,
3932 ),
3933 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003934 "error_if_validators": (
3935 TosaErrorValidator.evOutputListThenGraphMismatch,
3936 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003937 TosaErrorValidator.evCondIfCondNotMatchingBool,
3938 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003940 },
3941 "cond_if_binary": {
3942 "op": Op.COND_IF,
3943 "operands": (2, 0),
3944 "build_fcn": (
3945 build_cond_if_binary,
3946 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003947 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003948 TosaArgGen.agCondIf,
3949 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003950 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003951 "error_if_validators": (
3952 TosaErrorValidator.evInputListThenGraphMismatch,
3953 TosaErrorValidator.evInputListElseGraphMismatch,
3954 TosaErrorValidator.evOutputListThenGraphMismatch,
3955 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003956 TosaErrorValidator.evCondIfCondNotMatchingBool,
3957 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003958 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003959 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003960 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003961 "while_loop": {
3962 "op": Op.WHILE_LOOP,
3963 "operands": (0, 1),
3964 "build_fcn": (
3965 build_while_loop,
3966 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003967 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003968 TosaArgGen.agWhileLoop,
3969 ),
3970 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003971 "error_if_validators": (
3972 TosaErrorValidator.evInputListOutputListMismatch,
3973 TosaErrorValidator.evInputListCondGraphMismatch,
3974 TosaErrorValidator.evInputListBodyGraphInputMismatch,
3975 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
3976 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003977 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003978 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003979 },
Luke Hutton57287132023-02-06 14:54:18 +00003980 "fft2d": {
3981 "op": Op.FFT2D,
3982 "operands": (2, 0),
3983 "rank": (3, 3),
3984 "build_fcn": (
3985 build_fft2d,
3986 TosaTensorGen.tgFFT2d,
3987 TosaTensorValuesGen.tvgDefault,
3988 TosaArgGen.agFFT2d,
3989 ),
3990 "types": [DType.FP32],
3991 "error_if_validators": (
3992 TosaErrorValidator.evWrongInputType,
3993 TosaErrorValidator.evWrongOutputType,
3994 TosaErrorValidator.evWrongInputList,
3995 TosaErrorValidator.evWrongOutputList,
3996 TosaErrorValidator.evWrongRank,
3997 TosaErrorValidator.evBatchMismatch,
3998 TosaErrorValidator.evKernelNotPowerOfTwo,
3999 TosaErrorValidator.evFFTInputShapeMismatch,
4000 TosaErrorValidator.evFFTOutputShapeMismatch,
4001 ),
4002 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004003 "rfft2d": {
4004 "op": Op.RFFT2D,
4005 "operands": (1, 0),
4006 "rank": (3, 3),
4007 "build_fcn": (
4008 build_rfft2d,
4009 TosaTensorGen.tgRFFT2d,
4010 TosaTensorValuesGen.tvgDefault,
4011 TosaArgGen.agNone,
4012 ),
4013 "types": [DType.FP32],
4014 "error_if_validators": (
4015 TosaErrorValidator.evWrongInputType,
4016 TosaErrorValidator.evWrongOutputType,
4017 TosaErrorValidator.evWrongInputList,
4018 TosaErrorValidator.evWrongOutputList,
4019 TosaErrorValidator.evWrongRank,
4020 TosaErrorValidator.evBatchMismatch,
4021 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004022 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004023 ),
4024 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004025 }
4026
Kevin Cheng550ccc52021-03-03 11:21:43 -08004027
Eric Kunzee5e26762020-10-13 16:11:07 -07004028class OutputShaper:
4029 # Methods in this class compute the expected output shape and datatype
4030 # for common classes of operations
def __init__(self):
    """Construct the shaper; nothing is stored on the instance."""
    return None
4033
4034 # These methods return arguments that can be used for
4035 # creating a new output tensor
4036 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting two-input operator.

    Every output dimension copies the matching dimension of ``a``; when no
    error is being injected (``error_name is None``) a dimension of size 1
    in ``a`` takes its size from ``b`` instead.  The output dtype is
    ``a.dtype`` unless ``error_name`` is ``ErrorIf.WrongOutputType``, in
    which case a dtype other than ``a.dtype`` is drawn with ``rng.choice``.

    Returns whatever ``ser.addOutput`` returns.
    """
    # Ranks must agree unless a rank mismatch is the error being generated.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[idx] if dim == 1 and error_name is None else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    # Deliberately wrong output type: anything in the list except a.dtype.
    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    mismatched = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004065
4066 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for a two-input operator without broadcasting.

    Both inputs must have the same rank, dtype and per-dimension sizes
    (checked with assertions).  The output gets a fresh list holding the
    dimensions of ``a`` and the dtype of ``a``.

    Returns whatever ``ser.addOutput`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        # No broadcasting here: each dimension has to match exactly.
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004077
4078 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for a single-input operator.

    The output reuses ``a.shape``.  Its dtype is ``a.dtype`` unless
    ``error_name`` is ``ErrorIf.WrongOutputType``, in which case a dtype
    other than ``a.dtype`` is drawn with ``rng.choice``.

    Returns whatever ``ser.addOutput`` returns.
    """
    result_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Deliberately wrong output type: anything in the list except a.dtype.
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidate_dtypes) - set([a.dtype]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(a.shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004096
4097 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select between ``a`` and ``b``.

    The output shape follows ``cond.shape``; when no error is being
    injected, a size-1 dimension of ``cond`` is widened to the largest of
    the corresponding ``cond``/``a``/``b`` dimensions.  The output dtype is
    ``a.dtype`` unless ``error_name`` is ``ErrorIf.WrongOutputType``, in
    which case a dtype other than ``a.dtype`` is drawn with ``rng.choice``.

    Returns whatever ``ser.addOutput`` returns.
    """
    # Ranks must agree unless a rank mismatch is the error being generated.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    normal_run = error_name is None
    out_shape = []
    for idx, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and normal_run:
            cond_dim = max(cond_dim, a.shape[idx], b.shape[idx])
        out_shape.append(cond_dim)

    result_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Deliberately wrong output type: anything in the list except a.dtype.
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidate_dtypes) - set([a.dtype]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004126
4127 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a two-input comparison operator.

    Each output dimension copies the matching dimension of ``a``, except
    that a size-1 dimension of ``a`` takes the size from ``b`` when ``b``
    has a dimension at that index.  The output dtype is ``DType.BOOL``
    unless ``error_name`` is ``ErrorIf.WrongOutputType``, in which case one
    of the listed non-boolean dtypes is drawn with ``rng.choice``.

    Returns whatever ``ser.addOutput`` returns.
    """
    # Ranks must agree unless a rank mismatch is the error being generated.
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast, guarding the lookup into b because its rank may be smaller.
    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if dim == 1 and b_rank > idx else dim
        for idx, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.WrongOutputType:
        result_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        result_dtype = DType.BOOL

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004156
4157 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    The output shape is a copy of ``a.shape`` with the ``axis`` entry set
    to 1.  For the axis-related error cases the copy is left untouched,
    except that ``ErrorIf.ShapeOfAxisNotOne`` replaces an ``axis`` entry
    that is already 1 with ``rng.integers(2, 10)``.  The output dtype is
    ``a.dtype`` unless ``error_name`` is ``ErrorIf.WrongOutputType``, in
    which case a dtype other than ``a.dtype`` is drawn with ``rng.choice``.

    Returns whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        # The reduced axis must not end up as 1 for this error case.
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    result_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Deliberately wrong output type: anything in the list except a.dtype.
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidate_dtypes) - set([a.dtype]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004185
4186 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for an argmax along ``axis``.

    The output shape is a copy of ``a.shape`` with the ``axis`` entry
    removed (kept when the error case is an out-of-range axis).  For
    ``ErrorIf.ArgmaxOutputRankMismatch`` one dimension is randomly dropped
    or appended; for ``ErrorIf.ArgmaxOutputShapeMismatch`` every dimension
    is increased by ``rng.integers(1, 10)``.  The output dtype is
    ``DType.INT32`` unless ``error_name`` is ``ErrorIf.WrongOutputType``,
    in which case another dtype is drawn with ``rng.choice``.

    Returns whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Only drop a dimension when at least one would remain afterwards.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    result_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        # Deliberately wrong output type: anything in the list except INT32.
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidate_dtypes) - set([DType.INT32]))
        result_dtype = rng.choice(mismatched)

    return ser.addOutput(out_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004219
4220 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor for a 2D convolution.

    Layouts, as noted in the original comments: IFM is NHWC, filter is
    OHWI, OFM is NHWC.  The output is ``[N, h, w, O]`` with ``h`` and ``w``
    computed from the input size, padding, dilated kernel size and stride.
    For ``ErrorIf.ConvOutputShapeMismatch`` ``h`` and/or ``w`` are enlarged
    by a random multiple of the stride.  The output dtype is
    ``accum_dtype`` (``DType.INT32`` for ``ErrorIf.WrongInputType``); for
    ``ErrorIf.WrongOutputType`` it is replaced by a dtype drawn from
    ``usableDTypes`` with the valid ones excluded.

    Returns whatever ``ser.addOutput`` returns.
    """

    def out_extent(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Output size along one spatial axis.
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> change height, 2 -> change width, 3 -> change both
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in (1, 3):
            out_h = out_h + (rng.choice(steps) * strides[0])
        if which in (2, 3):
            out_w = out_w + (rng.choice(steps) * strides[1])

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # Pick some potentially correct output dtype if input type is incorrect
    if error_name == ErrorIf.WrongInputType:
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004271
4272 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004273 def conv3dOp(
4274 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
4275 ):
Kevin Cheng1533b852021-09-01 12:51:58 -07004276
4277 # IFM: NDHWC
4278 # Filter: ODHWI
4279 # OFM: NDHWC
4280
4281 d = (
4282 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004283 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004284 + padding[0]
4285 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004286 - (filter.shape[1] - 1) * dilations[0]
Kevin Cheng1533b852021-09-01 12:51:58 -07004287 ) // strides[0] + 1
4288
4289 h = (
4290 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004291 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004292 + padding[2]
4293 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004294 - (filter.shape[2] - 1) * dilations[1]
Kevin Cheng1533b852021-09-01 12:51:58 -07004295 ) // strides[1] + 1
4296
4297 w = (
4298 ifm.shape[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004299 - 1
Kevin Cheng1533b852021-09-01 12:51:58 -07004300 + padding[4]
4301 + padding[5]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004302 - (filter.shape[3] - 1) * dilations[2]
Kevin Cheng1533b852021-09-01 12:51:58 -07004303 ) // strides[2] + 1
4304
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004305 if error_name == ErrorIf.ConvOutputShapeMismatch:
4306 choices = [1, 2, 3, 4]
4307 change = rng.choice(choices)
4308 # increment in multiples of stride to not hit non-integer error case
4309 if change in [1, 4]:
4310 d = d + (rng.choice(choices) * strides[0])
4311 if change in [2, 4]:
4312 h = h + (rng.choice(choices) * strides[1])
4313 if change in [3, 4]:
4314 w = w + (rng.choice(choices) * strides[2])
Les Bell0e027d42021-11-09 14:42:14 +00004315
Kevin Cheng1533b852021-09-01 12:51:58 -07004316 ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
4317
James Ward8b390432022-08-12 20:48:56 +01004318 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004319 # Pick some potentially correct output dtype if input type is incorrect
4320 out_dtype = DType.INT32
Kevin Cheng1533b852021-09-01 12:51:58 -07004321 else:
James Ward8b390432022-08-12 20:48:56 +01004322 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004323
4324 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004325 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004326 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004327 else:
4328 excludes = [out_dtype]
4329 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004330 out_dtype = rng.choice(wrong_dtypes)
Kevin Cheng1533b852021-09-01 12:51:58 -07004331
4332 return ser.addOutput(ofm_shape, out_dtype)
4333
4334 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004335 def depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +01004336 ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004337 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07004338 # IFM: NHWC
4339 # Filter: HWCM
4340 # OFM: NHW C*M
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004341
Kevin Cheng550ccc52021-03-03 11:21:43 -08004342 h = (
4343 ifm.shape[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004344 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004345 + padding[0]
4346 + padding[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004347 - (filter.shape[0] - 1) * dilations[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004348 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004349
Kevin Cheng550ccc52021-03-03 11:21:43 -08004350 w = (
4351 ifm.shape[2]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004352 - 1
Kevin Cheng550ccc52021-03-03 11:21:43 -08004353 + padding[2]
4354 + padding[3]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004355 - (filter.shape[1] - 1) * dilations[1]
Kevin Cheng550ccc52021-03-03 11:21:43 -08004356 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07004357
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004358 if error_name == ErrorIf.ConvOutputShapeMismatch:
4359 choices = [1, 2, 3]
4360 change = rng.choice(choices)
4361 # increment in multiples of stride to not hit non-integer error case
4362 if change in [1, 3]:
4363 h = h + (rng.choice(choices) * strides[0])
4364 if change in [2, 3]:
4365 w = w + (rng.choice(choices) * strides[1])
Les Bell0e027d42021-11-09 14:42:14 +00004366
Eric Kunzee5e26762020-10-13 16:11:07 -07004367 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
4368
James Ward8b390432022-08-12 20:48:56 +01004369 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004370 # Pick some potentially correct output dtype if input type is incorrect
4371 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004372 else:
James Ward8b390432022-08-12 20:48:56 +01004373 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004374
4375 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004376 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004377 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004378 else:
4379 excludes = [out_dtype]
4380 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004381 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004382
Kevin Cheng550ccc52021-03-03 11:21:43 -08004383 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004384
4385 @staticmethod
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004386 def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004387 # input: NHWC
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004388 if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004389 # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004390 h = 1
4391 w = 1
4392 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004393 h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
4394 w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004395
4396 if error_name == ErrorIf.PoolingOutputShapeMismatch:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004397 choices = [1, 2, 3]
4398 change = rng.choice(choices)
4399 # increment in multiples of stride to not hit non-integer error case
4400 if change in [1, 3]:
4401 h = h + (rng.choice(choices) * stride[0])
4402 if change in [2, 3]:
4403 w = w + (rng.choice(choices) * stride[1])
Eric Kunzee5e26762020-10-13 16:11:07 -07004404 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004405
4406 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004407 all_dtypes = [
4408 DType.INT8,
4409 DType.INT16,
4410 DType.INT32,
4411 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004412 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004413 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004414 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004415 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +01004416 wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
4417 outputDType = rng.choice(wrong_dtypes)
4418 else:
4419 outputDType = ifm.dtype
4420
4421 return ser.addOutput(ofm_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004422
4423 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004424 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004425 # input: N, IC
4426 # filter: OC, IC
4427 # output: N, OC
4428
4429 output_shape = [input.shape[0], filter.shape[0]]
4430
James Ward8b390432022-08-12 20:48:56 +01004431 # Validated in arg_gen (also invalidated for ErrorIf)
4432 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004433
Kevin Cheng550ccc52021-03-03 11:21:43 -08004434 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004435
4436 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004437 def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
Kevin Cheng2d60f002021-06-09 14:18:32 -07004438 # a: N, H, C
4439 # b: N, C, W
4440 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07004441
Kevin Cheng2d60f002021-06-09 14:18:32 -07004442 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004443
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004444 if error_name == ErrorIf.WrongOutputType:
4445 if a.dtype == DType.INT8:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004446 incorrect_types = (
4447 DType.INT4,
4448 DType.INT8,
4449 DType.INT16,
4450 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004451 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004452 DType.FP16,
4453 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004454 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004455 elif a.dtype == DType.INT16:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004456 incorrect_types = (
4457 DType.INT4,
4458 DType.INT8,
4459 DType.INT16,
4460 DType.INT32,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004461 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004462 DType.FP16,
4463 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004464 )
James Ward24dbc422022-10-19 12:20:31 +01004465 elif (
4466 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
4467 ):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004468 incorrect_types = (
4469 DType.INT4,
4470 DType.INT8,
4471 DType.INT16,
4472 DType.INT32,
4473 DType.INT48,
4474 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004475 out_dtype = rng.choice(a=incorrect_types)
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004476 elif error_name == ErrorIf.WrongInputType:
4477 # Pick some potentially correct output dtype if input type is incorrect
4478 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004479 else:
James Ward8b390432022-08-12 20:48:56 +01004480 out_dtype = accum_dtype # Validated in arg_gen
Eric Kunzee5e26762020-10-13 16:11:07 -07004481
Kevin Cheng550ccc52021-03-03 11:21:43 -08004482 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004483
4484 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004485 def concatOp(ser, rng, axis, *a, error_name=None):
Matthew Haddon818ab902021-07-27 09:12:49 +01004486 input1 = a[0]
4487 remaining_inputs = a[1:]
Eric Kunzee5e26762020-10-13 16:11:07 -07004488
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004489 # calculate the output shape, if possible, otherwise just use the first input shape
Matthew Haddon818ab902021-07-27 09:12:49 +01004490 output_shape = input1.shape.copy()
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004491 if not (
4492 # unable to concat tensors of different ranks
4493 error_name == ErrorIf.ConcatInputRankMismatch
4494 # unable to concat tensors along an invalid axis
4495 or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004496 ):
4497 for tensor in remaining_inputs:
4498 output_shape[axis] += tensor.shape[axis]
Eric Kunzee5e26762020-10-13 16:11:07 -07004499
Matthew Haddon01c359d2021-10-15 16:30:48 +01004500 if error_name == ErrorIf.ConcatShapeSumMismatch:
4501 output_shape[axis] += rng.integers(5, 10)
4502
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004503 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004504 all_dtypes = {
4505 DType.INT8,
4506 DType.INT16,
4507 DType.INT32,
4508 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004509 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004510 DType.FP16,
4511 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004512 }
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004513 wrong_dtypes = list(all_dtypes - set([input1.dtype]))
4514 outputDType = rng.choice(wrong_dtypes)
4515 else:
4516 outputDType = input1.dtype
Matthew Haddon818ab902021-07-27 09:12:49 +01004517
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004518 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004519
4520 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004521 def padOp(ser, rng, a, padding, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004522
4523 output_shape = a.shape.copy()
4524
4525 for i in range(len(output_shape)):
4526 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
4527
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004528 if error_name == ErrorIf.PadOutputShapeMismatch:
4529 bad_dim = rng.choice(range(len(output_shape)))
4530 output_shape[bad_dim] -= rng.choice([1, 2])
4531
Matthew Haddone807aae2021-10-11 18:12:58 +01004532 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004533 all_dtypes = [
4534 DType.INT8,
4535 DType.INT16,
4536 DType.INT32,
4537 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004538 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004539 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004540 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004542 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4543 outputDType = rng.choice(wrong_dtypes)
4544 else:
4545 outputDType = a.dtype
4546
4547 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004548
4549 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004550 def reshapeOp(ser, rng, a, shape, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004551 output_shape = shape.copy()
4552
Matthew Haddone807aae2021-10-11 18:12:58 +01004553 if error_name == ErrorIf.TensorSizeInputOutputMismatch:
4554 for i in range(len(output_shape)):
4555 output_shape[i] = output_shape[i] + rng.integers(1, 10)
4556
4557 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004558 all_dtypes = [
4559 DType.INT8,
4560 DType.INT16,
4561 DType.INT32,
4562 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004563 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004564 DType.FP16,
4565 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004566 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004567 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4568 outputDType = rng.choice(wrong_dtypes)
4569 else:
4570 outputDType = a.dtype
4571
4572 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004573
4574 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004575 def sliceOp(ser, rng, a, start, size, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004576
Matthew Haddone807aae2021-10-11 18:12:58 +01004577 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004578 all_dtypes = [
4579 DType.INT8,
4580 DType.INT16,
4581 DType.INT32,
4582 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004583 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004584 DType.FP16,
4585 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004586 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004587 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4588 outputDType = rng.choice(wrong_dtypes)
4589 else:
4590 outputDType = a.dtype
4591
4592 if error_name == ErrorIf.SizeOutputShapeMismatch:
4593 output_shape = size.copy()
4594 for index in range(len(output_shape)):
4595 if output_shape[index] <= 2:
4596 output_shape[index] = output_shape[index] + rng.choice([1, 2])
4597 else:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004598 output_shape[index] = output_shape[index] + rng.choice(
4599 [-2, -1, 1, 2]
4600 )
Matthew Haddone807aae2021-10-11 18:12:58 +01004601 else:
4602 output_shape = size.copy()
4603
4604 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004605
4606 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004607 def tileOp(ser, rng, a, multiples, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004608
4609 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08004610 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004611
4612 for i in range(len(output_shape)):
4613 output_shape[i] = a.shape[i] * multiples[i]
4614
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004615 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004616 all_dtypes = [
4617 DType.INT8,
4618 DType.INT16,
4619 DType.INT32,
4620 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004621 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004622 DType.FP16,
4623 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004624 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004625 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4626 outputDType = rng.choice(wrong_dtypes)
4627 else:
4628 outputDType = a.dtype
4629
4630 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004631
4632 @staticmethod
Matthew Haddone807aae2021-10-11 18:12:58 +01004633 def transposeOp(ser, rng, a, perms, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004634 output_shape = a.shape.copy()
Matthew Haddone807aae2021-10-11 18:12:58 +01004635
Kevin Cheng550ccc52021-03-03 11:21:43 -08004636 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07004637
Matthew Haddone807aae2021-10-11 18:12:58 +01004638 if error_name == ErrorIf.IndexOutsideBounds:
4639 for i in range(len(output_shape)):
4640 output_shape[i] = a.shape[0]
4641 else:
4642 for i in range(len(output_shape)):
4643 output_shape[i] = a.shape[perms[i]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004644
Matthew Haddone807aae2021-10-11 18:12:58 +01004645 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004646 all_dtypes = [
4647 DType.INT8,
4648 DType.INT16,
4649 DType.INT32,
4650 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004651 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004652 DType.FP16,
4653 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004654 ]
Matthew Haddone807aae2021-10-11 18:12:58 +01004655 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4656 outputDType = rng.choice(wrong_dtypes)
4657 else:
4658 outputDType = a.dtype
4659
4660 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004661
4662 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004663 def gatherOp(ser, rng, values, indices, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004664 if error_name != ErrorIf.WrongRank:
4665 assert len(values.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004666 assert len(indices.shape) == 2
4667 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07004668
Kevin Cheng77d0f762020-11-24 10:26:32 -08004669 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
4670
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004671 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004672 all_dtypes = [
4673 DType.INT8,
4674 DType.INT16,
4675 DType.INT32,
4676 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004677 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004678 DType.FP16,
4679 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004680 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004681 wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
4682 outputDType = rng.choice(wrong_dtypes)
4683 else:
4684 outputDType = values.dtype
4685
4686 return ser.addOutput(output_shape, outputDType)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004687
4688 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004689 def scatterOp(ser, rng, values_in, indices, input, error_name=None):
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00004690 if error_name != ErrorIf.WrongRank:
4691 assert len(values_in.shape) == 3
Kevin Cheng77d0f762020-11-24 10:26:32 -08004692 assert len(indices.shape) == 2
4693 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08004694 assert values_in.shape[0] == indices.shape[0] # N
4695 assert input.shape[1] == indices.shape[1] # W
4696 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08004697
4698 output_shape = values_in.shape
4699
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004700 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004701 all_dtypes = [
4702 DType.INT8,
4703 DType.INT16,
4704 DType.INT32,
4705 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004706 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004707 DType.FP16,
4708 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004709 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004710 wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
4711 outputDType = rng.choice(wrong_dtypes)
4712 else:
4713 outputDType = values_in.dtype
4714
4715 return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004716
4717 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004718 def tableOp(ser, rng, input, error_name=None):
4719 # Same shape as the input, dtype dependent on input dtype
4720 if error_name != ErrorIf.WrongInputType:
4721 assert input.dtype == DType.INT16 or input.dtype == DType.INT8
Kevin Chengfe392ce2021-10-18 21:51:55 +00004722 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004723 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004724 wrong_dtypes = [
4725 DType.INT8,
4726 DType.INT16,
4727 DType.INT32,
4728 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004729 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004730 DType.FP16,
4731 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004732 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004733 wrong_dtypes.remove(output_dtype)
4734 output_dtype = rng.choice(wrong_dtypes)
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004735 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004736
4737 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08004738 def resizeOp(
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004739 serializer,
4740 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004741 input,
4742 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004743 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004744 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004745 border,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004746 input_dtype,
4747 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004748 error_name=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004749 ):
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004750 # Calculate OH, OW
4751 scale_y_n = scale[0]
4752 scale_y_d = scale[1]
4753 scale_x_n = scale[2]
4754 scale_x_d = scale[3]
4755 if error_name == ErrorIf.ScaleSmallerEqualZero:
4756 scale_y_n = max(scale_y_n, 1)
4757 scale_y_d = max(scale_y_d, 1)
4758 scale_x_n = max(scale_x_n, 1)
4759 scale_x_d = max(scale_x_d, 1)
4760
4761 oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
4762 ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
4763
4764 if error_name is not None:
4765 # Make sure the output tensor is valid, which can occur when
4766 # scale, offset or border have been changed for ERROR_IFs
4767 oh = max(oh, 1)
4768 ow = max(ow, 1)
4769 if error_name != ErrorIf.MaxDimExceeded:
4770 oh = min(oh, MAX_RESIZE_DIMENSION - 1)
4771 ow = min(ow, MAX_RESIZE_DIMENSION - 1)
4772
4773 if error_name == ErrorIf.ResizeOutputShapeMismatch:
4774 choices = [1, 2, 3]
4775 change = rng.choice(choices)
4776 # increment in multiples of scale_y/x_d so we don't hit non-integer error case
4777 if change in [1, 3]:
4778 if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
4779 oh -= scale_y_d
4780 assert oh > 0 # Should have been caught in agResize
4781 else:
4782 oh += scale_y_d
4783 if change in [2, 3]:
4784 if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
4785 ow -= scale_x_d
4786 assert ow > 0 # Should have been caught in agResize
4787 else:
4788 ow += scale_x_d
4789
Matthew Haddon848efb42021-09-09 12:30:53 +01004790 if error_name == ErrorIf.WrongRank:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004791 output_dims = [
4792 input.shape[0],
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004793 oh,
4794 ow,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004795 input.shape[0],
4796 ]
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004797 elif error_name == ErrorIf.BatchMismatch:
4798 output_dims = [
4799 input.shape[0] + rng.integers(1, 10),
4800 oh,
4801 ow,
4802 input.shape[3],
4803 ]
4804 elif error_name == ErrorIf.ChannelMismatch:
4805 output_dims = [
4806 input.shape[0],
4807 oh,
4808 ow,
4809 input.shape[3] + rng.integers(1, 10),
4810 ]
Matthew Haddon848efb42021-09-09 12:30:53 +01004811 else:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004812 output_dims = [input.shape[0], oh, ow, input.shape[3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07004813
Matthew Haddon693ba9e2021-09-22 11:24:37 +01004814 return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004815
4816 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004817 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004818 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004819
4820 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01004821 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01004822 if error_name == ErrorIf.ConvOutputShapeMismatch:
4823 choices = [1, 2, 3]
4824 change = rng.choice(choices)
4825 if change in [1, 3]:
4826 output_shape[1] = output_shape[1] + rng.choice(choices)
4827 if change in [2, 3]:
4828 output_shape[2] = output_shape[2] + rng.choice(choices)
4829
James Ward8b390432022-08-12 20:48:56 +01004830 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00004831 # Pick some potentially correct output dtype if input type is incorrect
4832 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07004833 else:
James Ward8b390432022-08-12 20:48:56 +01004834 out_dtype = accum_dtype
Les Bell0e027d42021-11-09 14:42:14 +00004835
4836 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01004837 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004838 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01004839 else:
4840 excludes = [out_dtype]
4841 wrong_dtypes = list(usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00004842 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07004843
Kevin Cheng550ccc52021-03-03 11:21:43 -08004844 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004845
4846 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00004847 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
4848 outputs = []
4849
4850 assert ifm1.dtype == ifm2.dtype
4851 input_dtype = ifm1.dtype
4852
4853 if error_name != ErrorIf.FFTInputShapeMismatch:
4854 assert ifm1.shape == ifm2.shape
4855
4856 input_shape = ifm1.shape
4857 if error_name != ErrorIf.WrongRank:
4858 assert len(input_shape) == 3
4859
4860 output_shape = input_shape.copy()
4861 output_dtype = input_dtype
4862
4863 if error_name == ErrorIf.WrongOutputType:
4864 excludes = [DType.FP32]
4865 wrong_dtypes = list(usableDTypes(excludes=excludes))
4866 output_dtype = rng.choice(wrong_dtypes)
4867 elif error_name == ErrorIf.BatchMismatch:
4868 output_shape[0] += rng.integers(1, 10)
4869 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4870 modify_dim = rng.choice([1, 2])
4871 output_shape[modify_dim] += rng.integers(1, 10)
4872
4873 outputs.append(serializer.addOutput(output_shape, output_dtype))
4874 outputs.append(serializer.addOutput(output_shape, output_dtype))
4875 return outputs
4876
4877 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00004878 def rfft2dOp(serializer, rng, value, error_name=None):
4879 outputs = []
4880
4881 input_shape = value.shape
4882 if error_name != ErrorIf.WrongRank:
4883 assert len(input_shape) == 3
4884
4885 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
4886
4887 output_dtype = value.dtype
4888 if error_name == ErrorIf.WrongOutputType:
4889 excludes = [DType.FP32]
4890 wrong_dtypes = list(usableDTypes(excludes=excludes))
4891 output_dtype = rng.choice(wrong_dtypes)
4892 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00004893 output_shape[0] += rng.integers(1, 10)
4894 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4895 modify_dim = rng.choice([1, 2])
4896 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00004897
4898 outputs.append(serializer.addOutput(output_shape, output_dtype))
4899 outputs.append(serializer.addOutput(output_shape, output_dtype))
4900 return outputs