blob: 66084b47509abb3ae1a169cdb61456e394b21ca7 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010027 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000028 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010029 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010030 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010031 TOSA_8K_LEVEL_MAX_KERNEL = 8192
32 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033
    def __init__(self, args):
        """Set up the test generator from the parsed command-line ``args``.

        Records the output directory and random seed, seeds ``self.rng``,
        builds the operator lists, creates the quantization helper and derives
        the floating-point value range from ``args.tensor_fp_value_range``.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Replaced by createSerializer(); None until then.
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): these run before targetted_shape and the fp range below
        # are assigned; keep this order unless both methods are checked.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
        # Work out floating point range
        self.random_fp_low = min(args.tensor_fp_value_range)
        self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070048
49 def createSerializer(self, opName, testPath):
50 self.testPath = os.path.join(opName, testPath)
51
52 fullPath = os.path.join(self.basePath, self.testPath)
53 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010054 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070055
56 def getSerializer(self):
57 return self.ser
58
59 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080060 with open(
61 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
62 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070063 fd.write(self.ser.serialize())
64
Kevin Cheng550ccc52021-03-03 11:21:43 -080065 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
66 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070067
Matthew Haddon74567092021-07-16 15:38:20 +010068 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000069 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010070 seed = self.random_seed + 1
71 self.rng = np.random.default_rng(seed)
72
Eric Kunzee5e26762020-10-13 16:11:07 -070073 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070074 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070075 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070076 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070077 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070078 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070079 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010080 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
81 elif dtype == DType.UINT8:
82 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070083 elif dtype == DType.INT16:
84 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010085 elif dtype == DType.UINT16:
86 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070087 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080088 return np.int32(
89 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
90 )
Eric Kunzee5e26762020-10-13 16:11:07 -070091 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080092 return np.int64(
93 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
94 )
James Ward8b390432022-08-12 20:48:56 +010095 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010096 return np.float16(
97 self.rng.uniform(
98 low=self.random_fp_low, high=self.random_fp_high, size=shape
99 )
100 )
James Ward24dbc422022-10-19 12:20:31 +0100101 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100102 f32_tensor = np.float32(
103 self.rng.uniform(
104 low=self.random_fp_low, high=self.random_fp_high, size=shape
105 )
106 )
James Ward24dbc422022-10-19 12:20:31 +0100107 # Floor the last 16 bits of each f32 value
108 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100109 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100110 return np.float32(
111 self.rng.uniform(
112 low=self.random_fp_low, high=self.random_fp_high, size=shape
113 )
114 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700115 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800116 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700117
Kevin Cheng989cb052021-04-28 16:29:44 -0700118 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700119 placeholders = []
120
Kevin Cheng989cb052021-04-28 16:29:44 -0700121 assert len(shape_list) == len(dtype_list)
122
123 for idx, shape in enumerate(shape_list):
124 arr = self.getRandTensor(shape, dtype_list[idx])
125 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700126
127 return placeholders
128
Kevin Cheng989cb052021-04-28 16:29:44 -0700129 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700130 consts = []
131
Kevin Cheng989cb052021-04-28 16:29:44 -0700132 assert len(shape_list) == len(dtype_list)
133
134 for idx, shape in enumerate(shape_list):
135 arr = self.getRandTensor(shape, dtype_list[idx])
136 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700137
138 return consts
139
140 def makeShape(self, rank):
141 if self.targetted_shape:
142 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800143 return np.int32(
144 self.rng.integers(
145 low=self.args.tensor_shape_range[0],
146 high=self.args.tensor_shape_range[1],
147 size=rank,
148 )
149 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700150
151 def setTargetShape(self, shape):
152 self.targetted_shape = shape
153
154 def randInt(self, low=0, high=256):
155 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
156
157 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100158 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100159 return np.float32(
160 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
161 )
James Ward8b390432022-08-12 20:48:56 +0100162 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100163 return np.float16(
164 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
165 )
James Ward24dbc422022-10-19 12:20:31 +0100166 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100167 rand_f32 = np.float32(
168 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
169 )
James Ward24dbc422022-10-19 12:20:31 +0100170 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700171 elif dtype == DType.BOOL:
172 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700173 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700175 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700176 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100177 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 elif dtype == DType.INT16:
179 low, high = (-32768, 32768)
180 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800181 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700182 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800183 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 # Special size
185 return np.int64(self.rng.integers(low, high, size=1))[0]
186 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800187 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700188
189 return np.int32(self.rng.integers(low, high, size=1))[0]
190
191 def shapeStr(self, shape):
192
193 sStr = []
194 # Convert to strings
195 for i in shape:
196 sStr.append(str(i))
197
Kevin Cheng550ccc52021-03-03 11:21:43 -0800198 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700199
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100200 def typeStr(self, dtype):
201 if isinstance(dtype, list) or isinstance(dtype, tuple):
202 assert len(dtype) >= 2
203 strs = [self.typeStr(t) for t in dtype]
204 # Limit types to the first 2 as the 3rd is the accumulator
205 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100207 if dtype in DTYPE_ATTRIBUTES:
208 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700209 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100210 raise Exception(
211 "Unknown dtype, cannot convert to string: {}".format(dtype)
212 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700213
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100214 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100215 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100216 if dtype in DTYPE_ATTRIBUTES:
217 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700218 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100219 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700220
Luke Hutton57287132023-02-06 14:54:18 +0000221 def constrictBatchSize(self, shape):
222 # Limit the batch size unless an explicit target shape set
223 if self.args.max_batch_size and not self.args.target_shapes:
224 shape[0] = min(shape[0], self.args.max_batch_size)
225 return shape
226
James Ward30124a82023-02-02 14:56:33 +0000227 def makeDimension(self):
228 return self.randInt(
229 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
230 )
231
Eric Kunzee5e26762020-10-13 16:11:07 -0700232 # Argument generators
233 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
234 # Where the string descriptor is used to generate the test name and
235 # The build_fcn_arg_list is expanded and passed to the operator test
236 # build function
237
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100238 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
239 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
240
Matthew Haddon848efb42021-09-09 12:30:53 +0100241 # build_placeholder returns an int, ABS/other ops does not
242 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000243 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100244 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000245 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000246 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100247 return result_tens
248
249 # Ensure new output type has correct qinfo
250 if error_name == ErrorIf.WrongOutputType:
251 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000252 qinfo = [
253 TosaQuantGen.getZeroPoint(self, a.dtype),
254 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
255 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100256
257 # Invalidate Input/Output list for error if checks.
258 input_list = [a.name]
259 output_list = [result_tens.name]
260 pCount, cCount = op["operands"]
261 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000262 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
263 self, error_name, input_list, output_list
264 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100265
Les Bell729b0352021-11-24 10:28:21 +0000266 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100267 self.ser,
268 validator_fcns,
269 error_name,
270 op=op,
271 input_dtype=a.dtype,
272 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000273 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000274 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100275 input_list=input_list,
276 output_list=output_list,
277 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000278 ):
279 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100280
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000281 attr = None
282 if op["op"] == Op.NEGATE:
283 attr = ts.TosaSerializerAttribute()
284 attr.NegateAttribute(qinfo[0], qinfo[1])
285
286 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 return result_tens
288
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100289 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000290 result_tens = OutputShaper.binaryBroadcastOp(
291 self.ser, self.rng, a, b, error_name
292 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100293
294 # Invalidate Input/Output list for error if checks.
295 input_list = [a.name, b.name]
296 output_list = [result_tens.name]
297 pCount, cCount = op["operands"]
298 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000299 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
300 self, error_name, input_list, output_list
301 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100302
Les Bell729b0352021-11-24 10:28:21 +0000303 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100304 self.ser,
305 validator_fcns,
306 error_name,
307 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000308 input1=a,
309 input2=b,
310 input_dtype=a.dtype,
311 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000312 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100313 input_list=input_list,
314 output_list=output_list,
315 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000316 ):
317 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100318
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000319 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700320 return result_tens
321
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100322 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700323 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000324 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700325 return result_tens
326
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000327 def build_arithmetic_right_shift(
328 self, op, a, b, round, validator_fcns=None, error_name=None
329 ):
330 result_tens = OutputShaper.binaryBroadcastOp(
331 self.ser, self.rng, a, b, error_name
332 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100333
334 # Invalidate Input/Output list for error if checks.
335 input_list = [a.name, b.name]
336 output_list = [result_tens.name]
337 pCount, cCount = op["operands"]
338 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000339 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
340 self, error_name, input_list, output_list
341 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100342
Les Bell729b0352021-11-24 10:28:21 +0000343 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100344 self.ser,
345 validator_fcns,
346 error_name,
347 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000348 input1=a,
349 input2=b,
350 input_dtype=a.dtype,
351 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000352 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100353 input_list=input_list,
354 output_list=output_list,
355 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000356 ):
357 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800358
359 attr = ts.TosaSerializerAttribute()
360 attr.ArithmeticRightShiftAttribute(round)
361
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000362 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800363 return result_tens
364
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100365 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000366 result_tens = OutputShaper.binaryBroadcastOp(
367 self.ser, self.rng, a, b, error_name
368 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700369
370 # Special for multiply:
371 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100372 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700373 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100374 if error_name == ErrorIf.WrongOutputType:
375 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
376 outputDType = self.rng.choice(all_dtypes)
377 result_tens.setDtype(outputDType)
378
379 # Invalidate Input/Output list for error if checks.
380 input_list = [a.name, b.name]
381 output_list = [result_tens.name]
382 pCount, cCount = op["operands"]
383 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
385 self, error_name, input_list, output_list
386 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100387
Les Bell729b0352021-11-24 10:28:21 +0000388 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100389 self.ser,
390 validator_fcns,
391 error_name,
392 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 input1=a,
394 input2=b,
395 input_dtype=a.dtype,
396 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000397 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100398 input_list=input_list,
399 output_list=output_list,
400 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000401 ):
402 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700403
Kevin Chengaee1fac2020-11-11 13:54:06 -0800404 attr = ts.TosaSerializerAttribute()
405 attr.MulAttribute(shift)
406
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000407 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700408 return result_tens
409
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
411 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700412
Kevin Chengfe392ce2021-10-18 21:51:55 +0000413 attr = ts.TosaSerializerAttribute()
414 attr.TableAttribute(table)
415
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100416 # Invalidate Input/Output list for error if checks.
417 input_list = [a.name]
418 output_list = [result_tens.name]
419 pCount, cCount = op["operands"]
420 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000421 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
422 self, error_name, input_list, output_list
423 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100424
Les Bell729b0352021-11-24 10:28:21 +0000425 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100426 self.ser,
427 validator_fcns,
428 error_name,
429 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000430 input_shape=a.shape,
431 input_dtype=a.dtype,
432 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000433 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100434 input_list=input_list,
435 output_list=output_list,
436 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000437 ):
438 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100439
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000440 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700441
442 return result_tens
443
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100444 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
445 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
446
447 # Invalidate Input/Output list for error if checks.
448 input_list = [cond.name, a.name, b.name]
449 output_list = [result_tens.name]
450 pCount, cCount = op["operands"]
451 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000452 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
453 self, error_name, input_list, output_list
454 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100455
Les Bell729b0352021-11-24 10:28:21 +0000456 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100457 self.ser,
458 validator_fcns,
459 error_name,
460 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000461 input1=cond,
462 input2=a,
463 input3=b,
464 input_shape=a.shape,
465 input_dtype=a.dtype,
466 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000467 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100468 input_list=input_list,
469 output_list=output_list,
470 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000471 ):
472 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100473
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 self.ser.addOperator(
475 op["op"],
476 input_list,
477 output_list,
478 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 return result_tens
480
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100481 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000482 result_tens = OutputShaper.binaryComparisonOp(
483 self.ser, self.rng, a, b, error_name
484 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100485
486 # Invalidate Input/Output list for error if checks.
487 input_list = [a.name, b.name]
488 output_list = [result_tens.name]
489 pCount, cCount = op["operands"]
490 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000491 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
492 self, error_name, input_list, output_list
493 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100494
Les Bell729b0352021-11-24 10:28:21 +0000495 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100496 self.ser,
497 validator_fcns,
498 error_name,
499 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000500 input1=a,
501 input2=b,
502 input_shape=a.shape,
503 input_dtype=a.dtype,
504 output_shape=result_tens.shape,
505 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000506 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100507 input_list=input_list,
508 output_list=output_list,
509 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000510 ):
511 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100512
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000513 self.ser.addOperator(
514 op["op"],
515 input_list,
516 output_list,
517 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700518 return result_tens
519
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100520 def build_argmax(self, op, a, axis, validator_fcns, error_name):
521 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
522
523 # Invalidate Input/Output list for error if checks.
524 input_list = [a.name]
525 output_list = [result_tens.name]
526 pCount, cCount = op["operands"]
527 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000528 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
529 self, error_name, input_list, output_list
530 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100531
Les Bell729b0352021-11-24 10:28:21 +0000532 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100533 self.ser,
534 validator_fcns,
535 error_name,
536 op=op,
537 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000538 input_shape=a.shape,
539 input_dtype=a.dtype,
540 output_shape=result_tens.shape,
541 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000542 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100543 input_list=input_list,
544 output_list=output_list,
545 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000546 ):
547 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700548
549 attr = ts.TosaSerializerAttribute()
550 attr.AxisAttribute(axis)
551
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000552 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553 return result_tens
554
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000555 def build_pool2d(
556 self,
557 op,
558 input,
James Ward8b390432022-08-12 20:48:56 +0100559 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000560 stride,
561 pad,
562 kernel,
563 validator_fcns=None,
564 error_name=None,
565 qinfo=None,
566 ):
567 result_tens = OutputShaper.pool2dOp(
568 self.ser, self.rng, input, kernel, stride, pad, error_name
569 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100570
571 # Ensure new output type has correct qinfo
572 if error_name == ErrorIf.WrongInputType:
573 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000574 qinfo = [
575 TosaQuantGen.getZeroPoint(self, input.dtype),
576 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
577 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100578
579 # Invalidate Input/Output list for error if checks.
580 input_list = [input.name]
581 output_list = [result_tens.name]
582 pCount, cCount = op["operands"]
583 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000584 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
585 self, error_name, input_list, output_list
586 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100587
Les Bell729b0352021-11-24 10:28:21 +0000588 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100589 self.ser,
590 validator_fcns,
591 error_name,
592 op=op,
593 input_shape=input.shape,
594 input_dtype=input.dtype,
595 output_shape=result_tens.shape,
596 output_dtype=result_tens.dtype,
597 kernel=kernel,
598 stride=stride,
599 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000600 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000601 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100602 input_list=input_list,
603 output_list=output_list,
604 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000605 ):
606 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700607
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000608 if qinfo is None:
609 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700610
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000611 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100612 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000613
614 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700615 return result_tens
616
James Ward8b390432022-08-12 20:48:56 +0100617 def build_maxpool2d(
618 self,
619 op,
620 input,
621 stride,
622 pad,
623 kernel,
624 validator_fcns=None,
625 error_name=None,
626 qinfo=None,
627 ):
628 # Same as build_pool2d but manually sets accum_dtype value
629 # (maxpool has no accum_dtype)
630 return self.build_pool2d(
631 op,
632 input,
633 DType.UNKNOWN,
634 stride,
635 pad,
636 kernel,
637 validator_fcns,
638 error_name,
639 qinfo,
640 )
641
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D convolution operator and return its output tensor.

    The output tensor is produced by OutputShaper.conv2dOp. When the
    ERROR_IF validators reject the generated configuration, nothing is
    added to the serializer and None is returned.

    Args:
        padding: sequence of exactly 4 values (asserted).
        qinfo: indexable pair of zero points; qinfo[0] and qinfo[1] are
            forwarded to ConvAttribute. Regenerated for WrongInputType
            tests whose input dtype is not INT8/UINT8.
    """
    assert len(padding) == 4
    out_tensor = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    wrong_type_test = error_name == ErrorIf.WrongInputType
    if wrong_type_test and ifm.dtype not in (DType.INT8, DType.UINT8):
        # Recompute the zero points from the actual input/output dtypes.
        ifm_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        out_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [ifm_zp, out_zp]

    operand_total = sum(op["operands"])
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not accepted:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
713
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution operator and return its output tensor.

    The output tensor is produced by OutputShaper.conv3dOp. When the
    ERROR_IF validators reject the generated configuration, nothing is
    added to the serializer and None is returned.

    Args:
        padding: sequence of exactly 6 values (asserted).
        qinfo: indexable pair of zero points; qinfo[0] and qinfo[1] are
            forwarded to ConvAttribute. Regenerated for WrongInputType
            tests whose input dtype is not INT8/UINT8.
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    wrong_type_test = error_name == ErrorIf.WrongInputType
    if wrong_type_test and ifm.dtype not in (DType.INT8, DType.UINT8):
        # Recompute the zero points from the actual input/output dtypes.
        ifm_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        out_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [ifm_zp, out_zp]

    operand_total = sum(op["operands"])
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not accepted:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
785
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a transposed 2D convolution and return its output tensor.

    The output tensor is produced by OutputShaper.transposeConv2DOp from
    the requested output_shape. When the ERROR_IF validators reject the
    generated configuration, nothing is serialized and None is returned.

    Args:
        out_pad: sequence of exactly 4 values (asserted).
        qinfo: indexable pair of zero points; qinfo[0] and qinfo[1] are
            forwarded to TransposeConvAttribute. Regenerated for
            WrongInputType tests whose input dtype is not INT8/UINT8.
    """
    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    wrong_type_test = error_name == ErrorIf.WrongInputType
    if wrong_type_test and ifm.dtype not in (DType.INT8, DType.UINT8):
        # Recompute the zero points from the actual input/output dtypes.
        ifm_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        out_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [ifm_zp, out_zp]

    operand_total = sum(op["operands"])
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not accepted:
        return None

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, stride, output_shape, qinfo[0], qinfo[1]
    )

    self.ser.addOperator(op["op"], inputs, outputs, tconv_attr)
    return out_tensor
848
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a depthwise 2D convolution and return its output tensor.

    The output tensor is produced by OutputShaper.depthwiseConv2dOp. When
    the ERROR_IF validators reject the generated configuration, nothing
    is added to the serializer and None is returned.

    Args:
        qinfo: indexable pair of zero points; qinfo[0] and qinfo[1] are
            forwarded to ConvAttribute. Regenerated for WrongInputType
            tests whose input dtype is not INT8/UINT8.
    """
    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    wrong_type_test = error_name == ErrorIf.WrongInputType
    if wrong_type_test and ifm.dtype not in (DType.INT8, DType.UINT8):
        # Recompute the zero points from the actual input/output dtypes.
        ifm_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        out_zp = TosaQuantGen.getZeroPoint(self, out_tensor.dtype)
        qinfo = [ifm_zp, out_zp]

    operand_total = sum(op["operands"])
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=inputs,
        num_operands=operand_total,
        output_list=outputs,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=out_tensor.shape,
    )
    if not accepted:
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
919
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully-connected operator and return its output tensor.

    The output tensor is produced by OutputShaper.fullyConnectedOp. When
    the ERROR_IF validators reject the configuration, nothing is
    serialized and None is returned. qinfo[0] and qinfo[1] are forwarded
    to FullyConnectedAttribute.
    """
    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    # op["operands"] is unpacked as exactly two counts.
    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
        accum_dtype=accum_dtype,
    )
    if not accepted:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, fc_attr)
    return out_tensor
968
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL of tensors a and b and return the output tensor.

    The output tensor is produced by OutputShaper.matmulOp. When the
    ERROR_IF validators reject the configuration, nothing is serialized
    and None is returned. qinfo[0] and qinfo[1] are forwarded to
    MatMulAttribute.
    """
    out_tensor = OutputShaper.matmulOp(
        self.ser, self.rng, a, b, accum_dtype, error_name
    )

    # op["operands"] is unpacked as exactly two counts.
    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
        accum_dtype=accum_dtype,
    )
    if not accepted:
        return None

    mm_attr = ts.TosaSerializerAttribute()
    mm_attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], inputs, outputs, mm_attr)
    return out_tensor
1010
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction of tensor a along axis and return the output.

    The output tensor is produced by OutputShaper.reduceOp. When the
    ERROR_IF validators reject the configuration, nothing is serialized
    and None is returned.

    validator_fcns now defaults to None, matching every other builder
    (previously it was the only required validator_fcns parameter).
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1045
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a CLAMP of tensor a with random bounds; return the output.

    Two random values of a.dtype are drawn. Normally the smaller one is
    the lower bound. For ErrorIf.MaxSmallerMin tests the draw is repeated
    until the values differ, and the bounds are then deliberately
    inverted. Returns None when the ERROR_IF validators reject the
    configuration.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    v = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(v), max(v)
    else:
        # Equal values could not produce an inverted range, so redraw.
        while v[0] == v[1]:
            v = draw_pair()
        min_val, max_val = max(v), min(v)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float dtypes: bounds occupy the first two value slots.
        clamp_attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float dtypes: bounds occupy the last two value slots.
        clamp_attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], inputs, outputs, clamp_attr)
    return out_tensor
1101
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU of tensor a and return the output tensor.

    The attribute value is a random FP32 number. No ERROR_IF validation
    is performed here; validator_fcns is accepted but unused.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    leaky_attr = ts.TosaSerializerAttribute()

    leaky_param = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(leaky_param)

    self.ser.addOperator(op["op"], [a.name], [out_tensor.name], leaky_attr)
    return out_tensor
1110
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU of tensor a (no attribute, no ERROR_IF checks)."""
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    in_names, out_names = [a.name], [out_tensor.name]
    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1117
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a SIGMOID of tensor a and return the output tensor.

    Returns None (and serializes nothing) when the ERROR_IF validators
    reject the configuration.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1148
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a TANH of tensor a and return the output tensor.

    Returns None (and serializes nothing) when the ERROR_IF validators
    reject the configuration.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return out_tensor
1179
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a CONCAT of a variable number of tensors.

    The positional arguments are the input tensors followed by the axis
    as the final element. The axis is checked to be an int unless this
    is a WrongInputType test. Returns the output tensor, or None when the
    ERROR_IF validators reject the configuration.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # The variadic tensor list carries the axis as its last entry.
    axis = a[-1]
    tensors = a[:-1]

    out_tensor = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=tensors[0].shape,
        output_shape=out_tensor.shape,
        input_dtype=tensors[0].dtype,
        output_dtype=out_tensor.dtype,
        inputs=tensors,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001228
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD of tensor a and return the output tensor.

    padding is flattened (padding.flatten()) into the PadAttribute along
    with both pad constants. The attribute is written into the serializer
    builder before validation, as in the original ordering. Returns None
    when the ERROR_IF validators reject the configuration.
    """
    out_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    pad_attr = ts.TosaSerializerAttribute()
    flat_padding = padding.flatten()
    pad_attr.PadAttribute(self.ser.builder, flat_padding, pad_const_int, pad_const_float)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, pad_attr)
    return out_tensor
Eric Kunzee5e26762020-10-13 16:11:07 -07001277
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of tensor a to newShape; return the output.

    Returns None (and serializes nothing) when the ERROR_IF validators
    reject the configuration.
    """
    out_tensor = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op["op"], inputs, outputs, reshape_attr)
    return out_tensor
1313
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REVERSE of tensor a along axis; return the output.

    The output tensor comes from OutputShaper.unaryOp. Returns None when
    the ERROR_IF validators reject the configuration.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], inputs, outputs, axis_attr)
    return out_tensor
1348
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of tensor a using perms; return the output.

    The TransposeAttribute is created before validation, preserving the
    original statement order. Returns None when the ERROR_IF validators
    reject the configuration.
    """
    out_tensor = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], inputs, outputs, perm_attr)
    return out_tensor
1384
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of tensor a (start/size); return the output.

    Returns None (and serializes nothing) when the ERROR_IF validators
    reject the configuration.
    """
    out_tensor = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    # Let ERROR_IF tests invalidate the input/output name lists.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=out_tensor.dtype,
        start=start,
        size=size,
        result_tensors=[out_tensor],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
        input1=a,
    )
    if not accepted:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], inputs, outputs, slice_attr)
    return out_tensor
1423
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Build a TILE operator replicating ``a`` according to ``multiples``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    param_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    validation_kwargs = {
        "op": op,
        "input_shape": a.shape,
        "output_shape": tiled.shape,
        "input_dtype": a.dtype,
        "output_dtype": tiled.dtype,
        "result_tensors": [tiled],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": param_count + const_count,
        "input1": a,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return tiled
1458
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Build a GATHER from ``values`` using a freshly generated indices constant.

    The INT32 indices have shape (values.shape[0], W), where W is drawn with
    ``self.randInt`` from the configured tensor shape range. Every index lies
    in [0, values.shape[1]), so it never exceeds the values tensor.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    index_limit = values.shape[1]  # K
    shape_lo, shape_hi = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    index_width = self.randInt(shape_lo, shape_hi)  # W
    index_data = np.int32(
        self.rng.integers(low=0, high=index_limit, size=[values.shape[0], index_width])
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    gathered = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    param_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [gathered.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return gathered
1505
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Build a SCATTER of ``input`` into ``values_in`` with generated indices.

    The INT32 indices have shape (values_in.shape[0], input.shape[1]), and
    every index lies in [0, values_in.shape[1]).

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    index_limit = values_in.shape[1]  # K
    index_width = input.shape[1]  # W
    index_data = np.int32(
        self.rng.integers(
            low=0, high=index_limit, size=[values_in.shape[0], index_width]
        )
    )  # (N, W)
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    scattered = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    param_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, indices.name, input.name],
        [scattered.name],
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=scattered.shape,
        input_dtype=values_in.dtype,
        output_dtype=scattered.dtype,
        result_tensors=[scattered],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return scattered
1550
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Build a RESIZE of ``input`` with the given mode, scale, offset and border.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    param_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    validation_kwargs = {
        "op": op,
        "mode": mode,
        "scale": scale,
        "input_dtype": input_dtype,
        "output_dtype": output_dtype,
        "input_shape": input.shape,
        "output_shape": resized.shape,
        "offset": offset,
        "border": border,
        "input_list": in_names,
        "output_list": out_names,
        "result_tensors": [resized],
        "num_operands": param_count + const_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1612
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Build an IDENTITYN operator passing ``val`` and ``val2`` through unchanged.

    ``op`` is the operator description dict, as in the other ``build_*``
    methods. The serializer therefore receives ``op["op"]``, not the dict
    itself. Only the first result tensor is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1620
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark the already-created tensor ``val`` as a graph output and return it.

    No operator is emitted. ``op``, ``validator_fcns`` and ``error_name`` are
    unused here; they keep the common builder signature.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001624
1625 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Build a CAST of ``val`` to ``out_dtype``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    cast_out = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    param_count, const_count = op["operands"]
    # Invalidate Input/Output list for error if checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [cast_out.name]
    )

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=cast_out.shape,
        input_dtype=val.dtype,
        output_dtype=cast_out.dtype,
        result_tensors=[cast_out],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if checks_ok:
        self.ser.addOperator(op["op"], in_names, out_names)
        return cast_out
    return None
1658
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Build a RESCALE operator converting ``val`` to ``out_dtype``.

    Random input/output zero points are chosen according to the dtypes, or
    deliberately invalid ones for the zero-point ERROR_IF tests. Random
    scales are generated per channel (last dimension) or as a single scale.
    For valid scale32 tests, the input placeholder file is clamped in place
    so its values satisfy the apply_scale_32 shift requirement.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # One scale per channel (last dimension) or a single global scale.
    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # shift_arr[i] is an np.int32 scalar. Under NumPy 2 (NEP 50) promotion,
        # `1 << np.int32(s)` stays int32 and overflows once s reaches 32,
        # which small scales can presumably require. Python ints are exact.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -(1 << (shift - 1))
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1813
def _get_condition_tensor(self, op, cond, error_name):
    """Create the constant condition tensor consumed by a COND_IF operator.

    Normally this is a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool test gets a non-BOOL dtype instead, and the
    CondIfCondShapeNotSizeOne test gets shape [2] or [1, 2].
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        cond_type = DType.BOOL

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1830
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used and ``else_tens`` is ignored. For
    the output-list mismatch ERROR_IF tests, the affected block's constant
    gets a perturbed shape.

    Returns the INT32 result tensor, or None when the ERROR_IF validators
    reject the generated test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape
    then_mismatch = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if then_mismatch or else_mismatch:
        # Perturb every dimension of the expected output shape.
        bad_shape = deepcopy(then_tens.shape)
        for dim_idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[dim_idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[dim_idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator result uses the expected (unperturbed) output shape.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant that is also the block's output.
    for block_name, block_arr, mismatched in (
        (then_block, then_arr, then_mismatch),
        (else_block, else_arr, else_mismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if mismatched:
            block_out = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    checks_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if checks_ok else None
1899
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Build a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    FP32/BF16/FP16/INT32 inputs use ADD (THEN) and SUB (ELSE). INT8/INT16
    inputs use LOGICAL_RIGHT_SHIFT (THEN) and LOGICAL_LEFT_SHIFT (ELSE). For
    the input/output-list mismatch ERROR_IF tests, the affected block gets an
    input or output whose shape differs from ``a``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable. Reusing ``op`` would replace the operator
    # dict with else_op before it is passed to the validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1982
Matthew Haddon630c17c2021-10-14 15:05:41 +01001983 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001984 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001985
Kevin Cheng550ccc52021-03-03 11:21:43 -08001986 cond_block = "COND_BLOCK"
1987 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001988
1989 attr = ts.TosaSerializerAttribute()
1990 attr.WhileLoopAttribute(cond_block, body_block)
1991
1992 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001993 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001994 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001995 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001996
1997 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001998 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1999 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002000 if error_name == ErrorIf.InputListOutputListMismatch:
2001 incorrect_acc = deepcopy(acc)
2002 for i in range(len(incorrect_acc.shape)):
2003 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2004 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2005 else:
2006 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002007
2008 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002009 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002010 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002011 [iter.name, a.name, acc.name],
2012 [iter_out.name, a_out.name, acc_out.name],
2013 attr,
2014 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002015 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002016
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002017 if error_name in [
2018 ErrorIf.InputListCondGraphMismatch,
2019 ErrorIf.InputListBodyGraphInputMismatch,
2020 ErrorIf.InputListBodyGraphOutputMismatch,
2021 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002022 incorrect_iter = deepcopy(iter)
2023 for i in range(len(incorrect_iter.shape)):
2024 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2025 if len(incorrect_iter.shape) == 0:
2026 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2027
2028 incorrect_acc = deepcopy(acc)
2029 for i in range(len(incorrect_acc.shape)):
2030 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2031
Eric Kunzee5e26762020-10-13 16:11:07 -07002032 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002033 self.ser.addBasicBlock(cond_block)
2034
Matthew Haddon630c17c2021-10-14 15:05:41 +01002035 if error_name == ErrorIf.InputListCondGraphMismatch:
2036 self.ser.addInputTensor(incorrect_iter)
2037 self.ser.addInputTensor(a)
2038 self.ser.addInputTensor(incorrect_acc)
2039 else:
2040 self.ser.addInputTensor(iter)
2041 self.ser.addInputTensor(a)
2042 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002043 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002044
2045 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002046 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002047 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002048 cond_type = DType.BOOL
2049 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2050 choice = self.rng.choice([1, 2])
2051 if choice == 1:
2052 cond_shape = [3]
2053 else:
2054 cond_shape = [1, 2]
2055 else:
2056 cond_shape = []
2057 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002058
Kevin Cheng550ccc52021-03-03 11:21:43 -08002059 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002060
2061 # BODY block (input: a, acc, iter, output: a, acc, iter)
2062 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002063 self.ser.addBasicBlock(body_block)
2064
Matthew Haddon630c17c2021-10-14 15:05:41 +01002065 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2066 self.ser.addInputTensor(incorrect_iter)
2067 self.ser.addInputTensor(a)
2068 self.ser.addInputTensor(incorrect_acc)
2069 else:
2070 self.ser.addInputTensor(iter)
2071 self.ser.addInputTensor(a)
2072 self.ser.addInputTensor(acc)
2073
Kevin Cheng550ccc52021-03-03 11:21:43 -08002074 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002075
2076 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002077 iter_body_out = self.ser.addIntermediate(
2078 incorrect_iter.shape, incorrect_iter.dtype
2079 )
2080 acc_body_out = self.ser.addIntermediate(
2081 incorrect_acc.shape, incorrect_acc.dtype
2082 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002083 else:
2084 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2085 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2086
Eric Kunzee5e26762020-10-13 16:11:07 -07002087 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2088 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2089 self.ser.addOutputTensor(iter_body_out)
2090 self.ser.addOutputTensor(a)
2091 self.ser.addOutputTensor(acc_body_out)
2092
Les Bell729b0352021-11-24 10:28:21 +00002093 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002094 self.ser,
2095 validator_fcns,
2096 error_name,
2097 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002098 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002099 ):
2100 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002101
Eric Kunzee5e26762020-10-13 16:11:07 -07002102 return acc_out
2103
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Add an FFT2D operator taking the two input tensors val1 and val2.

    The result tensors are created by OutputShaper.fft2dOp. If the ERROR_IF
    validators reject the configuration, no operator is added and None is
    returned. Otherwise the operator is added with an FFTAttribute carrying
    ``inverse``, and the list of result tensors is returned.

    NOTE(review): val1/val2 are presumably the real and imaginary input
    parts - confirm against OutputShaper.fft2dOp.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    in_names = [val1.name, val2.name]
    placeholder_count, const_count = op["operands"]

    out_names = [tensor.name for tensor in results]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    # For input/output-list ERROR_IF tests this helper presumably alters the
    # name lists; otherwise they are expected back unchanged - TODO confirm
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if is_valid:
        fft_attr = ts.TosaSerializerAttribute()
        fft_attr.FFTAttribute(inverse)
        self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
        return results
    return None
2145
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add an RFFT2D operator consuming the single input tensor ``val``.

    The result tensors are created by OutputShaper.rfft2dOp. If the ERROR_IF
    validators reject the configuration, no operator is added and None is
    returned. Otherwise the list of result tensors is returned.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    in_names = [val.name]
    placeholder_count, const_count = op["operands"]

    out_names = [tensor.name for tensor in results]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    # For input/output-list ERROR_IF tests this helper presumably alters the
    # name lists; otherwise they are expected back unchanged - TODO confirm
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if is_valid:
        self.ser.addOperator(op["op"], in_names, out_names)
        return results
    return None
2179
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out which shapes, ranks and dtypes to generate tests for.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) and "types" are read.
        shapeFilter: requested shapes; falsy means no shape restriction.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. An op dtype that is itself a
            list is matched on its first element.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, required for negative tests. Its
            "param_reqs" override the rank/dtype/shape filters when not None.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys. Returns
        None for a negative test without a validator, or for any other
        testType.
    """
    if not shapeFilter:
        shapeFilter = [None]

    rank_lo, rank_hi = op["rank"]
    op_ranks = range(rank_lo, rank_hi + 1)
    if rankFilter is not None:
        # Keep only the requested ranks that the operator supports
        ranks = [rank for rank in rankFilter if rank_lo <= rank <= rank_hi]
    elif shapeFilter[0] is None:
        # Nothing requested: bound the default behaviour by the smaller of the
        # operator's rank range and 1-4 inclusive, to keep test counts small.
        default_ranks = range(1, 5)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    # Negative tests: the validator's requirements win over the filters
    reqs = validator(check=False, op=op)["param_reqs"]
    neg_ranks = ranks if reqs["rank"] is None else reqs["rank"]
    neg_dtypes = dtypes if reqs["dtype"] is None else reqs["dtype"]
    # Only the first two shapes are kept, to keep negative test numbers small
    neg_shapes = shapeFilter[:2] if reqs["shape"] is None else reqs["shape"]
    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2260
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Return the list of tests to generate for the operator ``opName``.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to. None
            (the default) or an empty list means no shape restriction.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for normal tests, or "negative" for ERROR_IF
            tests. Negative tests produce one set per entry in the op's
            "error_if_validators".

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples. error_name is None for positive tests.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            # Nothing to generate for this validator. Negative tests need a
            # validator, and unknown test types produce nothing. Skip it
            # instead of discarding tests collected for earlier validators.
            continue

        # The test-name prefix is the same for every test of this validator
        if testType == "positive":
            namePrefix = opName
        else:
            namePrefix = "{}_ERRORIF_{}".format(opName, error_name)

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each argument-list entry is an (argStr, args) tuple.
                    # argStr is appended to the test name; args are the build
                    # function arguments.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        testStr = "{}_{}_{}".format(namePrefix, shapeStr, typeStr)
                        if argStr:
                            testStr = "{}_{}".format(testStr, argStr)
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2367
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build one test graph for ``opName`` and serialize it if the build succeeds.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        testStr: test name, passed to createSerializer.
        dtype_or_dtypeList: a single dtype used for every operand, or an
            explicit per-operand list of dtypes.
        error_name: ERROR_IF name for negative tests, None otherwise.
        shapeList: one shape per operand.
        testArgs: extra arguments passed to the build function after the
            tensors.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        AssertionError: if the shape/dtype list lengths do not match the op's
            operand count (not checked for CONCAT).
        TypeError: re-raised from the build function after its inputs are
            printed.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _tgen, tvgen_fcn, _agen = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # A single dtype is repeated once per operand. CONCAT has a variable
        # number of inputs, so it is sized from the shape list instead.
        if op["op"] == Op.CONCAT:
            copies = len(shapeList)
        else:
            copies = num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    qinfo = None if qgen is None else qgen(self, op, dtype_or_dtypeList, error_name)

    # Build the random tensor operands
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        positional = [self, op, *tens, *testArgs]
        keywords = {}
        if error_if_validators is None:
            # Without ERROR_IF validators, qinfo is passed positionally last
            if qinfo is not None:
                positional.append(qinfo)
        else:
            keywords["validator_fcns"] = error_if_validators
            keywords["error_name"] = error_name
            if qinfo is not None:
                keywords["qinfo"] = qinfo
        resultName = build_fcn(*positional, **keywords)
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if not resultName:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid, serialize it
    self.serialize("test")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002461
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST into concrete ops.

    For every 2D kernel size, a conv2d_, depthwise_conv2d_ and
    transpose_conv2d_ entry named "<prefix>_<h>x<w>" is added. For every 3D
    kernel, a conv3d_ entry named "conv3d_<a>x<b>x<c>" is added. Each entry
    is a shallow copy of its template with "filter" set to the kernel and
    "template" set to False. Afterwards every entry still marked as a
    template is removed. When args.level8k is set, kernels built from
    TOSA_8K_LEVEL_MAX_KERNEL are used instead of the small defaults.
    """
    ops = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in ops:
        # Already created these lists (can occur when class is initialized more than once)
        return

    if self.args.level8k:
        big = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big], [big, 2]]
        kernels_3d = [[1, big, 1], [2, 2, big]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def instantiate(prefix, kernel):
        # Add "<prefix>_<k0>x<k1>[x<k2>]" copied from "<prefix>_TEMPLATE"
        name = "{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))
        entry = ops["{}_TEMPLATE".format(prefix)].copy()
        entry["filter"] = kernel
        entry["template"] = False
        ops[name] = entry

    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            instantiate(prefix, kernel)
    for kernel in kernels_3d:
        instantiate("conv3d", kernel)

    # Collect the template names first: deleting keys while iterating over
    # the dictionary is not allowed
    stale = [name for name, entry in ops.items() if entry.get("template")]
    for name in stale:
        del ops[name]
2517
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must have an "operands" value that unpacks to two items, a
    "build_fcn" value that unpacks to four items, a "types" field and an "op"
    field. A missing or malformed field raises Exception naming the
    offending op. Entries without a "rank" get DEFAULT_RANK_RANGE.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required tuples must exist and unpack to the expected arity
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(name)
            )

        # Required fields that only need to be present
        for field, message in (
            ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
            ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
        ):
            try:
                entry[field]
            except KeyError:
                raise Exception(message.format(name))

        # Optional field: default rank range
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002559
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the Op enum value of the operator under test
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TosaTensorGen function,
#     TosaTensorValuesGen function, TosaArgGen function); the arg-gen entry
#     may be falsy, in which case each test gets a single empty argument list
# 'types': list of datatypes to be tested; an entry may itself be a list of
#     datatypes (see TYPE_CONV below)
# 'qgen': optional TosaQuantGen function used to build quantization info
# 'invalid_test_validators': optional tuple of functions that flag positive
#     tests to drop because they would fail without matching an ERROR_IF
# 'error_if_validators': optional tuple of TosaErrorValidator functions; one
#     set of negative (ERROR_IF) tests is generated per validator
# 'template': optional; True marks an entry that createDynamicOpLists expands
#     into concrete ops and then deletes
# 'filter': kernel size, set by createDynamicOpLists on the expanded ops
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Rank range applied by initOpListDefaults to ops that do not specify 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002611
2612 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002613 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002614 "argmax": {
2615 "op": Op.ARGMAX,
2616 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002617 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002618 "build_fcn": (
2619 build_argmax,
2620 TosaTensorGen.tgBasic,
2621 TosaTensorValuesGen.tvgDefault,
2622 TosaArgGen.agAxis,
2623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002624 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002625 "error_if_validators": (
2626 TosaErrorValidator.evAxisSmallerZero,
2627 TosaErrorValidator.evAxisLargerRank,
2628 TosaErrorValidator.evArgmaxOutputRankMismatch,
2629 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2630 TosaErrorValidator.evWrongRank,
2631 TosaErrorValidator.evWrongInputType,
2632 TosaErrorValidator.evWrongOutputType,
2633 TosaErrorValidator.evWrongInputList,
2634 TosaErrorValidator.evWrongOutputList,
2635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002636 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002637 "avg_pool2d": {
2638 "op": Op.AVG_POOL2D,
2639 "operands": (1, 0),
2640 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002641 "build_fcn": (
2642 build_pool2d,
2643 TosaTensorGen.tgNHWC,
2644 TosaTensorValuesGen.tvgDefault,
2645 TosaArgGen.agPooling,
2646 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002647 "qgen": TosaQuantGen.qgUnary,
2648 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002649 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002650 "error_if_validators": (
2651 TosaErrorValidator.evKernelSmallerOne,
2652 TosaErrorValidator.evStrideSmallerOne,
2653 TosaErrorValidator.evPadSmallerZero,
2654 TosaErrorValidator.evWrongRank,
2655 TosaErrorValidator.evWrongInputType,
2656 TosaErrorValidator.evWrongOutputType,
2657 TosaErrorValidator.evWrongInputList,
2658 TosaErrorValidator.evWrongOutputList,
2659 TosaErrorValidator.evInputZeroPointNotZero,
2660 TosaErrorValidator.evOutputZeroPointNotZero,
2661 TosaErrorValidator.evPadLargerEqualKernel,
2662 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002663 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002664 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002665 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002666 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002667 "conv2d_TEMPLATE": {
2668 "op": Op.CONV2D,
2669 "operands": (1, 2),
2670 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002671 "build_fcn": (
2672 build_conv2d,
2673 TosaTensorGen.tgConv2D,
2674 TosaTensorValuesGen.tvgDefault,
2675 TosaArgGen.agConv,
2676 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002677 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002678 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002679 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2680 "error_if_validators": (
2681 TosaErrorValidator.evWrongInputType,
2682 TosaErrorValidator.evWrongOutputType,
2683 TosaErrorValidator.evWrongInputList,
2684 TosaErrorValidator.evWrongOutputList,
2685 TosaErrorValidator.evInputZeroPointNotZero,
2686 TosaErrorValidator.evWeightZeroPointNotZero,
2687 TosaErrorValidator.evPadSmallerZero,
2688 TosaErrorValidator.evStrideSmallerOne,
2689 TosaErrorValidator.evDilationSmallerOne,
2690 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002691 TosaErrorValidator.evConvOutputShapeMismatch,
2692 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002693 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002694 "template": True,
2695 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002696 # Templated operator. Filled in by createDynamicOpLists
2697 "conv3d_TEMPLATE": {
2698 "op": Op.CONV3D,
2699 "operands": (1, 2),
2700 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002701 "build_fcn": (
2702 build_conv3d,
2703 TosaTensorGen.tgConv3D,
2704 TosaTensorValuesGen.tvgDefault,
2705 TosaArgGen.agConv,
2706 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002707 "qgen": TosaQuantGen.qgConv,
2708 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002709 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2710 "error_if_validators": (
2711 TosaErrorValidator.evWrongInputType,
2712 TosaErrorValidator.evWrongOutputType,
2713 TosaErrorValidator.evWrongInputList,
2714 TosaErrorValidator.evWrongOutputList,
2715 TosaErrorValidator.evInputZeroPointNotZero,
2716 TosaErrorValidator.evWeightZeroPointNotZero,
2717 TosaErrorValidator.evPadSmallerZero,
2718 TosaErrorValidator.evStrideSmallerOne,
2719 TosaErrorValidator.evDilationSmallerOne,
2720 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002721 TosaErrorValidator.evConvOutputShapeMismatch,
2722 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002723 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002724 "template": True,
2725 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002726 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002727 "depthwise_conv2d_TEMPLATE": {
2728 "op": Op.DEPTHWISE_CONV2D,
2729 "operands": (1, 2),
2730 "filter": [1, 1],
2731 "rank": (4, 4),
2732 "build_fcn": (
2733 build_depthwise_conv2d,
2734 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002735 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002736 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002737 ),
2738 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002739 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002740 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2741 "error_if_validators": (
2742 TosaErrorValidator.evWrongInputType,
2743 TosaErrorValidator.evWrongOutputType,
2744 TosaErrorValidator.evWrongInputList,
2745 TosaErrorValidator.evWrongOutputList,
2746 TosaErrorValidator.evInputZeroPointNotZero,
2747 TosaErrorValidator.evWeightZeroPointNotZero,
2748 TosaErrorValidator.evPadSmallerZero,
2749 TosaErrorValidator.evStrideSmallerOne,
2750 TosaErrorValidator.evDilationSmallerOne,
2751 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002752 TosaErrorValidator.evConvOutputShapeMismatch,
2753 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002754 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002755 "template": True,
2756 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002757 "fully_connected": {
2758 "op": Op.FULLY_CONNECTED,
2759 "operands": (1, 2),
2760 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002761 "build_fcn": (
2762 build_fully_connected,
2763 TosaTensorGen.tgFullyConnected,
2764 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002765 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002767 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002768 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002769 "error_if_validators": (
2770 TosaErrorValidator.evInputZeroPointNotZero,
2771 TosaErrorValidator.evWeightZeroPointNotZero,
2772 TosaErrorValidator.evWrongRank,
2773 TosaErrorValidator.evWrongInputType,
2774 TosaErrorValidator.evWrongOutputType,
2775 TosaErrorValidator.evWrongInputList,
2776 TosaErrorValidator.evWrongOutputList,
2777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002778 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002779 "matmul": {
2780 "op": Op.MATMUL,
2781 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002782 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002783 "build_fcn": (
2784 build_matmul,
2785 TosaTensorGen.tgMatmul,
2786 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002787 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002788 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002789 "qgen": TosaQuantGen.qgMatmul,
2790 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002791 "error_if_validators": (
2792 TosaErrorValidator.evInputZeroPointNotZero,
2793 TosaErrorValidator.evWrongRank,
2794 TosaErrorValidator.evWrongInputType,
2795 TosaErrorValidator.evWrongOutputType,
2796 TosaErrorValidator.evWrongInputList,
2797 TosaErrorValidator.evWrongOutputList,
2798 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002799 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002800 "max_pool2d": {
2801 "op": Op.MAX_POOL2D,
2802 "operands": (1, 0),
2803 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002804 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002805 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002806 TosaTensorGen.tgNHWC,
2807 TosaTensorValuesGen.tvgDefault,
2808 TosaArgGen.agPooling,
2809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002810 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002811 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002812 "error_if_validators": (
2813 TosaErrorValidator.evKernelSmallerOne,
2814 TosaErrorValidator.evStrideSmallerOne,
2815 TosaErrorValidator.evPadSmallerZero,
2816 TosaErrorValidator.evWrongRank,
2817 TosaErrorValidator.evWrongInputType,
2818 TosaErrorValidator.evWrongOutputType,
2819 TosaErrorValidator.evWrongInputList,
2820 TosaErrorValidator.evWrongOutputList,
2821 TosaErrorValidator.evPadLargerEqualKernel,
2822 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002823 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002824 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002825 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002826 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002827 "transpose_conv2d_TEMPLATE": {
2828 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002829 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002830 "rank": (4, 4),
2831 "build_fcn": (
2832 build_transpose_conv2d,
2833 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002834 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002835 TosaArgGen.agTransposeConv2D,
2836 ),
2837 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002838 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002839 "invalid_test_validators": (
2840 TosaInvalidValidator.ivHeightWidthInvalid,
2841 TosaInvalidValidator.ivNonPositiveOutputShape,
2842 ),
2843 "error_if_validators": (
2844 TosaErrorValidator.evWrongInputType,
2845 TosaErrorValidator.evWrongOutputType,
2846 TosaErrorValidator.evWrongInputList,
2847 TosaErrorValidator.evWrongOutputList,
2848 TosaErrorValidator.evInputZeroPointNotZero,
2849 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002850 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002851 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002852 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002853 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002854 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002855 "template": True,
2856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002857 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002858 "clamp": {
2859 "op": Op.CLAMP,
2860 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002861 "build_fcn": (
2862 build_clamp,
2863 TosaTensorGen.tgBasic,
2864 TosaTensorValuesGen.tvgDefault,
2865 None,
2866 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002867 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002868 "error_if_validators": (
2869 TosaErrorValidator.evMaxSmallerMin,
2870 TosaErrorValidator.evWrongInputType,
2871 TosaErrorValidator.evWrongOutputType,
2872 TosaErrorValidator.evWrongInputList,
2873 TosaErrorValidator.evWrongOutputList,
2874 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002875 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002876 "sigmoid": {
2877 "op": Op.SIGMOID,
2878 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002879 "build_fcn": (
2880 build_sigmoid,
2881 TosaTensorGen.tgBasic,
2882 TosaTensorValuesGen.tvgDefault,
2883 None,
2884 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002885 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002886 "error_if_validators": (
2887 TosaErrorValidator.evWrongInputType,
2888 TosaErrorValidator.evWrongOutputType,
2889 TosaErrorValidator.evWrongInputList,
2890 TosaErrorValidator.evWrongOutputList,
2891 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002892 },
2893 "tanh": {
2894 "op": Op.TANH,
2895 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002896 "build_fcn": (
2897 build_tanh,
2898 TosaTensorGen.tgBasic,
2899 TosaTensorValuesGen.tvgDefault,
2900 None,
2901 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002902 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002903 "error_if_validators": (
2904 TosaErrorValidator.evWrongInputType,
2905 TosaErrorValidator.evWrongOutputType,
2906 TosaErrorValidator.evWrongInputList,
2907 TosaErrorValidator.evWrongOutputList,
2908 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002910 # Elementwise Binary Operators
2911 "add": {
2912 "op": Op.ADD,
2913 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002914 "build_fcn": (
2915 build_binary_broadcast,
2916 TosaTensorGen.tgBroadcastFuzz,
2917 TosaTensorValuesGen.tvgAddSub,
2918 None,
2919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002920 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002921 "error_if_validators": (
2922 TosaErrorValidator.evRankMismatch,
2923 TosaErrorValidator.evWrongInputType,
2924 TosaErrorValidator.evWrongOutputType,
2925 TosaErrorValidator.evWrongInputList,
2926 TosaErrorValidator.evWrongOutputList,
2927 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002928 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002930 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002931 "arithmetic_right_shift": {
2932 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2933 "operands": (2, 0),
2934 "build_fcn": (
2935 build_arithmetic_right_shift,
2936 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002937 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002938 TosaArgGen.agArithmeticRightShift,
2939 ),
2940 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002941 "error_if_validators": (
2942 TosaErrorValidator.evRankMismatch,
2943 TosaErrorValidator.evWrongInputType,
2944 TosaErrorValidator.evWrongOutputType,
2945 TosaErrorValidator.evWrongInputList,
2946 TosaErrorValidator.evWrongOutputList,
2947 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002948 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002949 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002950 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002951 "bitwise_and": {
2952 "op": Op.BITWISE_AND,
2953 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002954 "build_fcn": (
2955 build_binary_broadcast,
2956 TosaTensorGen.tgBroadcastFuzz,
2957 TosaTensorValuesGen.tvgDefault,
2958 None,
2959 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002960 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002961 "error_if_validators": (
2962 TosaErrorValidator.evRankMismatch,
2963 TosaErrorValidator.evWrongInputType,
2964 TosaErrorValidator.evWrongOutputType,
2965 TosaErrorValidator.evWrongInputList,
2966 TosaErrorValidator.evWrongOutputList,
2967 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002968 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002969 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002970 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002971 "bitwise_or": {
2972 "op": Op.BITWISE_OR,
2973 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002974 "build_fcn": (
2975 build_binary_broadcast,
2976 TosaTensorGen.tgBroadcastFuzz,
2977 TosaTensorValuesGen.tvgDefault,
2978 None,
2979 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002980 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002981 "error_if_validators": (
2982 TosaErrorValidator.evRankMismatch,
2983 TosaErrorValidator.evWrongInputType,
2984 TosaErrorValidator.evWrongOutputType,
2985 TosaErrorValidator.evWrongInputList,
2986 TosaErrorValidator.evWrongOutputList,
2987 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002988 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002989 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002991 "bitwise_xor": {
2992 "op": Op.BITWISE_XOR,
2993 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002994 "build_fcn": (
2995 build_binary_broadcast,
2996 TosaTensorGen.tgBroadcastFuzz,
2997 TosaTensorValuesGen.tvgDefault,
2998 None,
2999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003000 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003001 "error_if_validators": (
3002 TosaErrorValidator.evRankMismatch,
3003 TosaErrorValidator.evWrongInputType,
3004 TosaErrorValidator.evWrongOutputType,
3005 TosaErrorValidator.evWrongInputList,
3006 TosaErrorValidator.evWrongOutputList,
3007 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003008 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003009 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003010 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003011 "intdiv": {
3012 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003013 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003014 "build_fcn": (
3015 build_binary_broadcast,
3016 TosaTensorGen.tgBroadcastFuzz,
3017 TosaTensorValuesGen.tvgIntDiv,
3018 None,
3019 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003020 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003021 "error_if_validators": (
3022 TosaErrorValidator.evRankMismatch,
3023 TosaErrorValidator.evWrongInputType,
3024 TosaErrorValidator.evWrongOutputType,
3025 TosaErrorValidator.evWrongInputList,
3026 TosaErrorValidator.evWrongOutputList,
3027 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003028 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003029 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003031 "logical_and": {
3032 "op": Op.LOGICAL_AND,
3033 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003034 "build_fcn": (
3035 build_binary_broadcast,
3036 TosaTensorGen.tgBroadcastFuzz,
3037 TosaTensorValuesGen.tvgDefault,
3038 None,
3039 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003040 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003041 "error_if_validators": (
3042 TosaErrorValidator.evRankMismatch,
3043 TosaErrorValidator.evWrongInputType,
3044 TosaErrorValidator.evWrongOutputType,
3045 TosaErrorValidator.evWrongInputList,
3046 TosaErrorValidator.evWrongOutputList,
3047 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003048 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003049 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003050 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003051 "logical_left_shift": {
3052 "op": Op.LOGICAL_LEFT_SHIFT,
3053 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003054 "build_fcn": (
3055 build_binary_broadcast,
3056 TosaTensorGen.tgBroadcastFuzz,
3057 TosaTensorValuesGen.tvgLogicalShift,
3058 None,
3059 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003060 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003061 "error_if_validators": (
3062 TosaErrorValidator.evRankMismatch,
3063 TosaErrorValidator.evWrongInputType,
3064 TosaErrorValidator.evWrongOutputType,
3065 TosaErrorValidator.evWrongInputList,
3066 TosaErrorValidator.evWrongOutputList,
3067 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003068 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003070 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003071 "logical_right_shift": {
3072 "op": Op.LOGICAL_RIGHT_SHIFT,
3073 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003074 "build_fcn": (
3075 build_binary_broadcast,
3076 TosaTensorGen.tgBroadcastFuzz,
3077 TosaTensorValuesGen.tvgLogicalShift,
3078 None,
3079 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003080 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003081 "error_if_validators": (
3082 TosaErrorValidator.evRankMismatch,
3083 TosaErrorValidator.evWrongInputType,
3084 TosaErrorValidator.evWrongOutputType,
3085 TosaErrorValidator.evWrongInputList,
3086 TosaErrorValidator.evWrongOutputList,
3087 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003088 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003089 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003090 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003091 "logical_or": {
3092 "op": Op.LOGICAL_OR,
3093 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003094 "build_fcn": (
3095 build_binary_broadcast,
3096 TosaTensorGen.tgBroadcastFuzz,
3097 TosaTensorValuesGen.tvgDefault,
3098 None,
3099 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003100 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003101 "error_if_validators": (
3102 TosaErrorValidator.evRankMismatch,
3103 TosaErrorValidator.evWrongInputType,
3104 TosaErrorValidator.evWrongOutputType,
3105 TosaErrorValidator.evWrongInputList,
3106 TosaErrorValidator.evWrongOutputList,
3107 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003108 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003109 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003111 "logical_xor": {
3112 "op": Op.LOGICAL_XOR,
3113 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003114 "build_fcn": (
3115 build_binary_broadcast,
3116 TosaTensorGen.tgBroadcastFuzz,
3117 TosaTensorValuesGen.tvgDefault,
3118 None,
3119 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003120 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003121 "error_if_validators": (
3122 TosaErrorValidator.evRankMismatch,
3123 TosaErrorValidator.evWrongInputType,
3124 TosaErrorValidator.evWrongOutputType,
3125 TosaErrorValidator.evWrongInputList,
3126 TosaErrorValidator.evWrongOutputList,
3127 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003128 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003129 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003130 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003131 "maximum": {
3132 "op": Op.MAXIMUM,
3133 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003134 "build_fcn": (
3135 build_binary_broadcast,
3136 TosaTensorGen.tgBroadcastFuzz,
3137 TosaTensorValuesGen.tvgDefault,
3138 None,
3139 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003140 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003141 "error_if_validators": (
3142 TosaErrorValidator.evRankMismatch,
3143 TosaErrorValidator.evWrongInputType,
3144 TosaErrorValidator.evWrongOutputType,
3145 TosaErrorValidator.evWrongInputList,
3146 TosaErrorValidator.evWrongOutputList,
3147 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003148 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003149 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003151 "minimum": {
3152 "op": Op.MINIMUM,
3153 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003154 "build_fcn": (
3155 build_binary_broadcast,
3156 TosaTensorGen.tgBroadcastFuzz,
3157 TosaTensorValuesGen.tvgDefault,
3158 None,
3159 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003160 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003161 "error_if_validators": (
3162 TosaErrorValidator.evRankMismatch,
3163 TosaErrorValidator.evWrongInputType,
3164 TosaErrorValidator.evWrongOutputType,
3165 TosaErrorValidator.evWrongInputList,
3166 TosaErrorValidator.evWrongOutputList,
3167 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003168 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003169 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003170 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003171 "mul": {
3172 "op": Op.MUL,
3173 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003174 "build_fcn": (
3175 build_mul,
3176 TosaTensorGen.tgBroadcastFuzz,
3177 TosaTensorValuesGen.tvgMul,
3178 TosaArgGen.agMul,
3179 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003180 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003181 "error_if_validators": (
3182 TosaErrorValidator.evWrongInputType,
3183 TosaErrorValidator.evWrongOutputType,
3184 TosaErrorValidator.evWrongInputList,
3185 TosaErrorValidator.evWrongOutputList,
3186 TosaErrorValidator.evRankMismatch,
3187 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003188 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003189 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003190 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003191 "pow": {
3192 "op": Op.POW,
3193 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003194 "build_fcn": (
3195 build_binary_broadcast,
3196 TosaTensorGen.tgBroadcastFuzz,
3197 TosaTensorValuesGen.tvgDefault,
3198 None,
3199 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003200 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003201 "error_if_validators": (
3202 TosaErrorValidator.evRankMismatch,
3203 TosaErrorValidator.evWrongInputType,
3204 TosaErrorValidator.evWrongOutputType,
3205 TosaErrorValidator.evWrongInputList,
3206 TosaErrorValidator.evWrongOutputList,
3207 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003208 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003209 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003211 "sub": {
3212 "op": Op.SUB,
3213 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003214 "build_fcn": (
3215 build_binary_broadcast,
3216 TosaTensorGen.tgBroadcastFuzz,
3217 TosaTensorValuesGen.tvgAddSub,
3218 None,
3219 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003220 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003221 "error_if_validators": (
3222 TosaErrorValidator.evRankMismatch,
3223 TosaErrorValidator.evWrongInputType,
3224 TosaErrorValidator.evWrongOutputType,
3225 TosaErrorValidator.evWrongInputList,
3226 TosaErrorValidator.evWrongOutputList,
3227 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003228 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003229 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003230 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003231 "table": {
3232 "op": Op.TABLE,
3233 # Use the automatic generation functions to create the input array
3234 # but create the table tensor in the build function, as it may be
3235 # a different type from the input
3236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 "build_fcn": (
3238 build_table,
3239 TosaTensorGen.tgBasic,
3240 TosaTensorValuesGen.tvgDefault,
3241 TosaArgGen.agTable,
3242 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003243 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003244 "error_if_validators": (
3245 TosaErrorValidator.evWrongInputType,
3246 TosaErrorValidator.evWrongOutputType,
3247 TosaErrorValidator.evWrongInputList,
3248 TosaErrorValidator.evWrongOutputList,
3249 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003250 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003251 # Elementwise Unary operators
3252 "abs": {
3253 "op": Op.ABS,
3254 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003255 "build_fcn": (
3256 build_unary,
3257 TosaTensorGen.tgBasic,
3258 TosaTensorValuesGen.tvgDefault,
3259 None,
3260 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003261 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003262 "error_if_validators": (
3263 TosaErrorValidator.evWrongInputType,
3264 TosaErrorValidator.evWrongOutputType,
3265 TosaErrorValidator.evWrongInputList,
3266 TosaErrorValidator.evWrongOutputList,
3267 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003269 "bitwise_not": {
3270 "op": Op.BITWISE_NOT,
3271 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 "build_fcn": (
3273 build_unary,
3274 TosaTensorGen.tgBasic,
3275 TosaTensorValuesGen.tvgDefault,
3276 None,
3277 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003278 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003279 "error_if_validators": (
3280 TosaErrorValidator.evWrongInputType,
3281 TosaErrorValidator.evWrongOutputType,
3282 TosaErrorValidator.evWrongInputList,
3283 TosaErrorValidator.evWrongOutputList,
3284 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 "ceil": {
3287 "op": Op.CEIL,
3288 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003289 "build_fcn": (
3290 build_unary,
3291 TosaTensorGen.tgBasic,
3292 TosaTensorValuesGen.tvgDefault,
3293 None,
3294 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003295 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003296 "error_if_validators": (
3297 TosaErrorValidator.evWrongInputType,
3298 TosaErrorValidator.evWrongOutputType,
3299 TosaErrorValidator.evWrongInputList,
3300 TosaErrorValidator.evWrongOutputList,
3301 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003302 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 "clz": {
3304 "op": Op.CLZ,
3305 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003306 "build_fcn": (
3307 build_unary,
3308 TosaTensorGen.tgBasic,
3309 TosaTensorValuesGen.tvgDefault,
3310 None,
3311 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003312 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003313 "error_if_validators": (
3314 TosaErrorValidator.evWrongInputType,
3315 TosaErrorValidator.evWrongOutputType,
3316 TosaErrorValidator.evWrongInputList,
3317 TosaErrorValidator.evWrongOutputList,
3318 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 "exp": {
3321 "op": Op.EXP,
3322 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003323 "build_fcn": (
3324 build_unary,
3325 TosaTensorGen.tgBasic,
3326 TosaTensorValuesGen.tvgDefault,
3327 None,
3328 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003329 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003330 "error_if_validators": (
3331 TosaErrorValidator.evWrongInputType,
3332 TosaErrorValidator.evWrongOutputType,
3333 TosaErrorValidator.evWrongInputList,
3334 TosaErrorValidator.evWrongOutputList,
3335 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003336 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 "floor": {
3338 "op": Op.FLOOR,
3339 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003340 "build_fcn": (
3341 build_unary,
3342 TosaTensorGen.tgBasic,
3343 TosaTensorValuesGen.tvgDefault,
3344 None,
3345 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003346 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003347 "error_if_validators": (
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongInputList,
3351 TosaErrorValidator.evWrongOutputList,
3352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003353 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 "log": {
3355 "op": Op.LOG,
3356 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003357 "build_fcn": (
3358 build_unary,
3359 TosaTensorGen.tgBasic,
3360 TosaTensorValuesGen.tvgDefault,
3361 None,
3362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003364 "error_if_validators": (
3365 TosaErrorValidator.evWrongInputType,
3366 TosaErrorValidator.evWrongOutputType,
3367 TosaErrorValidator.evWrongInputList,
3368 TosaErrorValidator.evWrongOutputList,
3369 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003371 "logical_not": {
3372 "op": Op.LOGICAL_NOT,
3373 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003374 "build_fcn": (
3375 build_unary,
3376 TosaTensorGen.tgBasic,
3377 TosaTensorValuesGen.tvgDefault,
3378 None,
3379 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003380 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003381 "error_if_validators": (
3382 TosaErrorValidator.evWrongInputType,
3383 TosaErrorValidator.evWrongOutputType,
3384 TosaErrorValidator.evWrongInputList,
3385 TosaErrorValidator.evWrongOutputList,
3386 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003387 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003388 "negate": {
3389 "op": Op.NEGATE,
3390 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003391 "build_fcn": (
3392 build_unary,
3393 TosaTensorGen.tgBasic,
3394 TosaTensorValuesGen.tvgNegate,
3395 None,
3396 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003397 "qgen": TosaQuantGen.qgUnary,
3398 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003399 "error_if_validators": (
3400 TosaErrorValidator.evInputZeroPointNotZero,
3401 TosaErrorValidator.evOutputZeroPointNotZero,
3402 TosaErrorValidator.evWrongInputType,
3403 TosaErrorValidator.evWrongOutputType,
3404 TosaErrorValidator.evWrongInputList,
3405 TosaErrorValidator.evWrongOutputList,
3406 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003407 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 "reciprocal": {
3409 "op": Op.RECIPROCAL,
3410 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003411 "build_fcn": (
3412 build_unary,
3413 TosaTensorGen.tgBasic,
3414 TosaTensorValuesGen.tvgDefault,
3415 None,
3416 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003417 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003418 "error_if_validators": (
3419 TosaErrorValidator.evWrongInputType,
3420 TosaErrorValidator.evWrongOutputType,
3421 TosaErrorValidator.evWrongInputList,
3422 TosaErrorValidator.evWrongOutputList,
3423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003424 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 "rsqrt": {
3426 "op": Op.RSQRT,
3427 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003428 "build_fcn": (
3429 build_unary,
3430 TosaTensorGen.tgBasic,
3431 TosaTensorValuesGen.tvgDefault,
3432 None,
3433 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003434 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003435 "error_if_validators": (
3436 TosaErrorValidator.evWrongInputType,
3437 TosaErrorValidator.evWrongOutputType,
3438 TosaErrorValidator.evWrongInputList,
3439 TosaErrorValidator.evWrongOutputList,
3440 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003441 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003442 # Elementwise Ternary operators
3443 "select": {
3444 "op": Op.SELECT,
3445 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003446 "build_fcn": (
3447 build_select,
3448 TosaTensorGen.tgBroadcastFuzz,
3449 TosaTensorValuesGen.tvgSelect,
3450 None,
3451 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003452 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003453 "error_if_validators": (
3454 TosaErrorValidator.evRankMismatch,
3455 TosaErrorValidator.evWrongInputType,
3456 TosaErrorValidator.evWrongOutputType,
3457 TosaErrorValidator.evWrongInputList,
3458 TosaErrorValidator.evWrongOutputList,
3459 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003460 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003461 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003462 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003463 # Comparison operators
3464 "equal": {
3465 "op": Op.EQUAL,
3466 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003467 "build_fcn": (
3468 build_comparison,
3469 TosaTensorGen.tgBroadcastFuzz,
3470 TosaTensorValuesGen.tvgEqual,
3471 None,
3472 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003474 "error_if_validators": (
3475 TosaErrorValidator.evRankMismatch,
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003481 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003482 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003483 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "greater_equal": {
3485 "op": Op.GREATER_EQUAL,
3486 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003487 "build_fcn": (
3488 build_comparison,
3489 TosaTensorGen.tgBroadcastFuzz,
3490 TosaTensorValuesGen.tvgDefault,
3491 None,
3492 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003493 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003494 "error_if_validators": (
3495 TosaErrorValidator.evRankMismatch,
3496 TosaErrorValidator.evWrongInputType,
3497 TosaErrorValidator.evWrongOutputType,
3498 TosaErrorValidator.evWrongInputList,
3499 TosaErrorValidator.evWrongOutputList,
3500 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003501 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003502 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 "greater": {
3505 "op": Op.GREATER,
3506 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003507 "build_fcn": (
3508 build_comparison,
3509 TosaTensorGen.tgBroadcastFuzz,
3510 TosaTensorValuesGen.tvgDefault,
3511 None,
3512 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003514 "error_if_validators": (
3515 TosaErrorValidator.evRankMismatch,
3516 TosaErrorValidator.evWrongInputType,
3517 TosaErrorValidator.evWrongOutputType,
3518 TosaErrorValidator.evWrongInputList,
3519 TosaErrorValidator.evWrongOutputList,
3520 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003521 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003524 # Reduction operators
3525 "reduce_all": {
3526 "op": Op.REDUCE_ALL,
3527 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003528 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 "build_fcn": (
3530 build_reduce,
3531 TosaTensorGen.tgBasic,
3532 TosaTensorValuesGen.tvgDefault,
3533 TosaArgGen.agAxis,
3534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003536 "error_if_validators": (
3537 TosaErrorValidator.evAxisLargerRank,
3538 TosaErrorValidator.evAxisSmallerZero,
3539 TosaErrorValidator.evShapeOfAxisNotOne,
3540 TosaErrorValidator.evWrongInputType,
3541 TosaErrorValidator.evWrongOutputType,
3542 TosaErrorValidator.evWrongRank,
3543 TosaErrorValidator.evWrongInputList,
3544 TosaErrorValidator.evWrongOutputList,
3545 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003546 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 "reduce_any": {
3548 "op": Op.REDUCE_ANY,
3549 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003550 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551 "build_fcn": (
3552 build_reduce,
3553 TosaTensorGen.tgBasic,
3554 TosaTensorValuesGen.tvgDefault,
3555 TosaArgGen.agAxis,
3556 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003557 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003558 "error_if_validators": (
3559 TosaErrorValidator.evAxisLargerRank,
3560 TosaErrorValidator.evAxisSmallerZero,
3561 TosaErrorValidator.evShapeOfAxisNotOne,
3562 TosaErrorValidator.evWrongInputType,
3563 TosaErrorValidator.evWrongOutputType,
3564 TosaErrorValidator.evWrongRank,
3565 TosaErrorValidator.evWrongInputList,
3566 TosaErrorValidator.evWrongOutputList,
3567 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003568 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003569 "reduce_max": {
3570 "op": Op.REDUCE_MAX,
3571 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003572 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003573 "build_fcn": (
3574 build_reduce,
3575 TosaTensorGen.tgBasic,
3576 TosaTensorValuesGen.tvgDefault,
3577 TosaArgGen.agAxis,
3578 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003579 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003580 "error_if_validators": (
3581 TosaErrorValidator.evAxisLargerRank,
3582 TosaErrorValidator.evAxisSmallerZero,
3583 TosaErrorValidator.evShapeOfAxisNotOne,
3584 TosaErrorValidator.evWrongInputType,
3585 TosaErrorValidator.evWrongOutputType,
3586 TosaErrorValidator.evWrongRank,
3587 TosaErrorValidator.evWrongInputList,
3588 TosaErrorValidator.evWrongOutputList,
3589 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003590 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003591 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003592 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003594 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003595 "build_fcn": (
3596 build_reduce,
3597 TosaTensorGen.tgBasic,
3598 TosaTensorValuesGen.tvgDefault,
3599 TosaArgGen.agAxis,
3600 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003601 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003602 "error_if_validators": (
3603 TosaErrorValidator.evAxisLargerRank,
3604 TosaErrorValidator.evAxisSmallerZero,
3605 TosaErrorValidator.evShapeOfAxisNotOne,
3606 TosaErrorValidator.evWrongInputType,
3607 TosaErrorValidator.evWrongOutputType,
3608 TosaErrorValidator.evWrongRank,
3609 TosaErrorValidator.evWrongInputList,
3610 TosaErrorValidator.evWrongOutputList,
3611 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003612 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003613 "reduce_product": {
3614 "op": Op.REDUCE_PRODUCT,
3615 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003616 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003617 "build_fcn": (
3618 build_reduce,
3619 TosaTensorGen.tgBasic,
3620 TosaTensorValuesGen.tvgDefault,
3621 TosaArgGen.agAxis,
3622 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003623 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 "error_if_validators": (
3625 TosaErrorValidator.evAxisLargerRank,
3626 TosaErrorValidator.evAxisSmallerZero,
3627 TosaErrorValidator.evShapeOfAxisNotOne,
3628 TosaErrorValidator.evWrongInputType,
3629 TosaErrorValidator.evWrongOutputType,
3630 TosaErrorValidator.evWrongRank,
3631 TosaErrorValidator.evWrongInputList,
3632 TosaErrorValidator.evWrongOutputList,
3633 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003634 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003635 "reduce_sum": {
3636 "op": Op.REDUCE_SUM,
3637 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003638 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003639 "build_fcn": (
3640 build_reduce,
3641 TosaTensorGen.tgBasic,
3642 TosaTensorValuesGen.tvgReduceSum,
3643 TosaArgGen.agAxis,
3644 ),
James Ward24dbc422022-10-19 12:20:31 +01003645 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003646 "error_if_validators": (
3647 TosaErrorValidator.evAxisLargerRank,
3648 TosaErrorValidator.evAxisSmallerZero,
3649 TosaErrorValidator.evShapeOfAxisNotOne,
3650 TosaErrorValidator.evWrongInputType,
3651 TosaErrorValidator.evWrongOutputType,
3652 TosaErrorValidator.evWrongRank,
3653 TosaErrorValidator.evWrongInputList,
3654 TosaErrorValidator.evWrongOutputList,
3655 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003656 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003657 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003658 "concat": {
3659 "op": Op.CONCAT,
3660 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003661 "build_fcn": (
3662 build_concat,
3663 TosaTensorGen.tgConcat,
3664 TosaTensorValuesGen.tvgConcat,
3665 TosaArgGen.agAxis,
3666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003668 "error_if_validators": (
3669 TosaErrorValidator.evAxisLargerRank,
3670 TosaErrorValidator.evAxisSmallerZero,
3671 TosaErrorValidator.evConcatInputRankMismatch,
3672 TosaErrorValidator.evConcatShapeSumMismatch,
3673 TosaErrorValidator.evConcatInputDimMismatch,
3674 TosaErrorValidator.evWrongInputType,
3675 TosaErrorValidator.evWrongOutputType,
3676 TosaErrorValidator.evWrongOutputList,
3677 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003678 },
3679 "pad": {
3680 "op": Op.PAD,
3681 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682 "build_fcn": (
3683 build_pad,
3684 TosaTensorGen.tgBasic,
3685 TosaTensorValuesGen.tvgDefault,
3686 TosaArgGen.agPad,
3687 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003688 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003689 "error_if_validators": (
3690 TosaErrorValidator.evWrongInputType,
3691 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003692 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003693 TosaErrorValidator.evWrongOutputType,
3694 TosaErrorValidator.evWrongInputList,
3695 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003696 TosaErrorValidator.evRankMismatch,
3697 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003698 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003699 },
3700 "reshape": {
3701 "op": Op.RESHAPE,
3702 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003703 "build_fcn": (
3704 build_reshape,
3705 TosaTensorGen.tgBasic,
3706 TosaTensorValuesGen.tvgDefault,
3707 TosaArgGen.agReshape,
3708 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003709 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003710 "error_if_validators": (
3711 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003716 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3717 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003718 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003719 },
3720 "reverse": {
3721 "op": Op.REVERSE,
3722 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003723 "build_fcn": (
3724 build_reverse,
3725 TosaTensorGen.tgBasic,
3726 TosaTensorValuesGen.tvgDefault,
3727 TosaArgGen.agAxis,
3728 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003729 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003730 "error_if_validators": (
3731 TosaErrorValidator.evAxisSmallerZero,
3732 TosaErrorValidator.evAxisLargerRank,
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003738 },
3739 "slice": {
3740 "op": Op.SLICE,
3741 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003742 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003743 "build_fcn": (
3744 build_slice,
3745 TosaTensorGen.tgBasic,
3746 TosaTensorValuesGen.tvgDefault,
3747 TosaArgGen.agSlice,
3748 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003749 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003750 "error_if_validators": (
3751 TosaErrorValidator.evStartSmallerZero,
3752 TosaErrorValidator.evSizeSmallerEqualZero,
3753 TosaErrorValidator.evStartSizeOutsideBounds,
3754 TosaErrorValidator.evSizeOutputShapeMismatch,
3755 TosaErrorValidator.evInputSizeStartLengthMismatch,
3756 TosaErrorValidator.evWrongRank,
3757 TosaErrorValidator.evWrongInputType,
3758 TosaErrorValidator.evWrongOutputType,
3759 TosaErrorValidator.evWrongInputList,
3760 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003761 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003763 },
3764 "tile": {
3765 "op": Op.TILE,
3766 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003767 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003768 "build_fcn": (
3769 build_tile,
3770 TosaTensorGen.tgBasic,
3771 TosaTensorValuesGen.tvgDefault,
3772 TosaArgGen.agTile,
3773 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003774 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003775 "error_if_validators": (
3776 TosaErrorValidator.evWrongInputType,
3777 TosaErrorValidator.evWrongOutputType,
3778 TosaErrorValidator.evWrongInputList,
3779 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003780 TosaErrorValidator.evRankMismatch,
3781 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003782 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003783 },
3784 "transpose": {
3785 "op": Op.TRANSPOSE,
3786 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003787 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003788 "build_fcn": (
3789 build_transpose,
3790 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003791 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003792 TosaArgGen.agTranspose,
3793 ),
3794 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003795 "error_if_validators": (
3796 TosaErrorValidator.evIndexOutsideBounds,
3797 TosaErrorValidator.evIndexUsedTwice,
3798 TosaErrorValidator.evWrongInputType,
3799 TosaErrorValidator.evWrongOutputType,
3800 TosaErrorValidator.evWrongInputList,
3801 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003802 TosaErrorValidator.evWrongRank,
3803 TosaErrorValidator.evRankMismatch,
3804 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003805 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003806 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 # Data nodes
3808 "const": {
3809 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003810 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003811 "build_fcn": (
3812 build_const,
3813 TosaTensorGen.tgBasic,
3814 TosaTensorValuesGen.tvgDefault,
3815 None,
3816 ),
Luke Hutton65872422023-02-20 10:33:04 +00003817 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "identity": {
3820 "op": Op.IDENTITY,
3821 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003822 "build_fcn": (
3823 build_unary,
3824 TosaTensorGen.tgBasic,
3825 TosaTensorValuesGen.tvgDefault,
3826 None,
3827 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003828 "types": TYPE_FIB,
3829 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003830 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003831 "gather": {
3832 "op": Op.GATHER,
3833 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3834 "operands": (1, 0),
3835 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003836 "build_fcn": (
3837 build_gather,
3838 TosaTensorGen.tgBasic,
3839 TosaTensorValuesGen.tvgDefault,
3840 None,
3841 ),
James Ward24dbc422022-10-19 12:20:31 +01003842 "types": (
3843 DType.INT8,
3844 DType.INT16,
3845 DType.INT32,
3846 DType.FP16,
3847 DType.BF16,
3848 DType.FP32,
3849 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003850 "error_if_validators": (
3851 TosaErrorValidator.evWrongInputType,
3852 TosaErrorValidator.evWrongOutputType,
3853 TosaErrorValidator.evWrongInputList,
3854 TosaErrorValidator.evWrongOutputList,
3855 TosaErrorValidator.evWrongRank,
3856 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003857 },
3858 "scatter": {
3859 "op": Op.SCATTER,
3860 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003861 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003862 "operands": (2, 0),
3863 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003864 "build_fcn": (
3865 build_scatter,
3866 TosaTensorGen.tgScatter,
3867 TosaTensorValuesGen.tvgDefault,
3868 None,
3869 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003870 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 "error_if_validators": (
3872 TosaErrorValidator.evWrongInputType,
3873 TosaErrorValidator.evWrongOutputType,
3874 TosaErrorValidator.evWrongInputList,
3875 TosaErrorValidator.evWrongOutputList,
3876 TosaErrorValidator.evWrongRank,
3877 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003878 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003879 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003880 "resize": {
3881 "op": Op.RESIZE,
3882 "operands": (1, 0),
3883 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003884 "build_fcn": (
3885 build_resize,
3886 TosaTensorGen.tgNHWC,
3887 TosaTensorValuesGen.tvgDefault,
3888 TosaArgGen.agResize,
3889 ),
James Ward24dbc422022-10-19 12:20:31 +01003890 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 "invalid_test_validators": (
3892 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003893 ),
3894 "error_if_validators": (
3895 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003896 TosaErrorValidator.evScaleSmallerEqualZero,
3897 TosaErrorValidator.evScaleNLargerMax,
3898 TosaErrorValidator.evScaleDLargerMax,
3899 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003900 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003901 TosaErrorValidator.evBorderSmallerMin,
3902 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003903 TosaErrorValidator.evWrongInputType,
3904 TosaErrorValidator.evWrongOutputType,
3905 TosaErrorValidator.evWrongRank,
3906 TosaErrorValidator.evWrongInputList,
3907 TosaErrorValidator.evWrongOutputList,
3908 TosaErrorValidator.evBatchMismatch,
3909 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003910 TosaErrorValidator.evResizeOutputShapeMismatch,
3911 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003912 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003913 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003914 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003915 "cast": {
3916 "op": Op.CAST,
3917 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003918 "build_fcn": (
3919 build_cast,
3920 TosaTensorGen.tgBasic,
3921 TosaTensorValuesGen.tvgDefault,
3922 TosaArgGen.agCast,
3923 ),
James Ward8b390432022-08-12 20:48:56 +01003924 "types": (
3925 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003926 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003927 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003928 DType.INT8,
3929 DType.INT16,
3930 DType.INT32,
3931 DType.BOOL,
3932 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003933 "error_if_validators": (
3934 TosaErrorValidator.evWrongInputType,
3935 TosaErrorValidator.evWrongOutputType,
3936 TosaErrorValidator.evWrongInputList,
3937 TosaErrorValidator.evWrongOutputList,
3938 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003939 },
3940 "rescale": {
3941 "op": Op.RESCALE,
3942 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003943 "build_fcn": (
3944 build_rescale,
3945 TosaTensorGen.tgBasic,
3946 TosaTensorValuesGen.tvgDefault,
3947 TosaArgGen.agRescale,
3948 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003949 "types": [
3950 DType.UINT8,
3951 DType.INT8,
3952 DType.INT16,
3953 DType.INT32,
3954 DType.INT48,
3955 DType.UINT16,
3956 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003957 "error_if_validators": (
3958 TosaErrorValidator.evInputZeroPointNotZero,
3959 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003960 TosaErrorValidator.evU16InputZeroPointNotValid,
3961 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 TosaErrorValidator.evScaleTrue,
3963 TosaErrorValidator.evScaleNotTrue,
3964 TosaErrorValidator.evWrongInputType,
3965 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003966 TosaErrorValidator.evWrongInputList,
3967 TosaErrorValidator.evWrongOutputList,
3968 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003969 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003970 # Custom
3971 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003972 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07003973 # Two variants of cond_if, one that generates one of two constant tensors (no
3974 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
3975 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08003976 "cond_if_const": {
3977 "op": Op.COND_IF,
3978 "operands": (0, 2),
3979 "build_fcn": (
3980 build_cond_if_const,
3981 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003982 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003983 TosaArgGen.agCondIf,
3984 ),
3985 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003986 "error_if_validators": (
3987 TosaErrorValidator.evOutputListThenGraphMismatch,
3988 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003989 TosaErrorValidator.evCondIfCondNotMatchingBool,
3990 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003991 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003992 },
3993 "cond_if_binary": {
3994 "op": Op.COND_IF,
3995 "operands": (2, 0),
3996 "build_fcn": (
3997 build_cond_if_binary,
3998 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004000 TosaArgGen.agCondIf,
4001 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004002 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004003 "error_if_validators": (
4004 TosaErrorValidator.evInputListThenGraphMismatch,
4005 TosaErrorValidator.evInputListElseGraphMismatch,
4006 TosaErrorValidator.evOutputListThenGraphMismatch,
4007 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004008 TosaErrorValidator.evCondIfCondNotMatchingBool,
4009 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004010 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004011 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004012 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004013 "while_loop": {
4014 "op": Op.WHILE_LOOP,
4015 "operands": (0, 1),
4016 "build_fcn": (
4017 build_while_loop,
4018 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004019 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004020 TosaArgGen.agWhileLoop,
4021 ),
4022 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004023 "error_if_validators": (
4024 TosaErrorValidator.evInputListOutputListMismatch,
4025 TosaErrorValidator.evInputListCondGraphMismatch,
4026 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4027 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4028 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004029 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004030 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004031 },
Luke Hutton57287132023-02-06 14:54:18 +00004032 "fft2d": {
4033 "op": Op.FFT2D,
4034 "operands": (2, 0),
4035 "rank": (3, 3),
4036 "build_fcn": (
4037 build_fft2d,
4038 TosaTensorGen.tgFFT2d,
4039 TosaTensorValuesGen.tvgDefault,
4040 TosaArgGen.agFFT2d,
4041 ),
4042 "types": [DType.FP32],
4043 "error_if_validators": (
4044 TosaErrorValidator.evWrongInputType,
4045 TosaErrorValidator.evWrongOutputType,
4046 TosaErrorValidator.evWrongInputList,
4047 TosaErrorValidator.evWrongOutputList,
4048 TosaErrorValidator.evWrongRank,
4049 TosaErrorValidator.evBatchMismatch,
4050 TosaErrorValidator.evKernelNotPowerOfTwo,
4051 TosaErrorValidator.evFFTInputShapeMismatch,
4052 TosaErrorValidator.evFFTOutputShapeMismatch,
4053 ),
4054 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004055 "rfft2d": {
4056 "op": Op.RFFT2D,
4057 "operands": (1, 0),
4058 "rank": (3, 3),
4059 "build_fcn": (
4060 build_rfft2d,
4061 TosaTensorGen.tgRFFT2d,
4062 TosaTensorValuesGen.tvgDefault,
4063 TosaArgGen.agNone,
4064 ),
4065 "types": [DType.FP32],
4066 "error_if_validators": (
4067 TosaErrorValidator.evWrongInputType,
4068 TosaErrorValidator.evWrongOutputType,
4069 TosaErrorValidator.evWrongInputList,
4070 TosaErrorValidator.evWrongOutputList,
4071 TosaErrorValidator.evWrongRank,
4072 TosaErrorValidator.evBatchMismatch,
4073 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004074 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004075 ),
4076 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004077 }
4078
Kevin Cheng550ccc52021-03-03 11:21:43 -08004079
Eric Kunzee5e26762020-10-13 16:11:07 -07004080class OutputShaper:
4081 # Methods in this class compute the expected output shape and datatype
4082 # for common classes of operations
def __init__(self):
    """Set up nothing: OutputShaper keeps no per-instance state."""
4085
4086 # These methods return arguments that can be used for
4087 # creating a new output tensor
4088 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of a broadcasting element-wise binary op.

    Output shape: for each dimension, ``b``'s size where ``a``'s size is 1
    and no ERROR_IF case is requested, otherwise ``a``'s size.

    Args:
        ser: serializer; the output is created with ``ser.addOutput``.
        rng: random generator exposing ``integers``/``choice`` (presumably a
            numpy ``Generator`` -- TODO confirm at call sites).
        a: first input; ``a.shape`` and ``a.dtype`` are read.
        b: second input; must have ``a``'s rank (unless ``error_name`` is
            ``ErrorIf.RankMismatch``) and ``a``'s dtype.
        error_name: optional ERROR_IF case. ``DimensionMismatch`` grows one
            randomly chosen output dimension by 1; ``WrongOutputType`` picks
            an output dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, a_dim in enumerate(a.shape):
        if a_dim == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a_dim)

    # The fuzz index is drawn even when DimensionMismatch is not requested
    # (as the generator always has), so the random values consumed by later
    # generation are unaffected. A rank-0 input has no dimension to perturb
    # and rng.integers(0, 0) raises ValueError, so skip the draw there.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Any dtype other than the input's makes the output type invalid.
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004121
4122 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output of a binary op whose inputs must match exactly.

    No broadcasting is done: ``a`` and ``b`` must agree in rank, dtype and
    every dimension (checked with ``assert``). The output gets a fresh list
    copy of ``a``'s dimensions and ``a``'s dtype.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = []
    for idx, dim in enumerate(a.shape):
        # Every dimension must already be identical; nothing is broadcast.
        assert dim == b.shape[idx]
        out_shape.append(dim)

    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004133
4134 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor of an element-wise unary op.

    The output always has ``a``'s shape. Its dtype is ``a.dtype`` except
    for the ``ErrorIf.WrongOutputType`` case, where ``rng.choice`` draws a
    dtype that differs from the input's.

    Args:
        ser: serializer; the output is created with ``ser.addOutput``.
        rng: random generator exposing ``choice`` (presumably a numpy
            ``Generator`` -- TODO confirm at call sites).
        a: input; ``a.shape`` and ``a.dtype`` are read.
        error_name: optional ERROR_IF case name.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        mismatched = list(set(candidate_dtypes) - {a.dtype})
        bad_dtype = rng.choice(mismatched)
        return ser.addOutput(a.shape, bad_dtype)

    # Valid case: the output mirrors the input's shape and dtype.
    return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004152
4153 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output tensor of SELECT.

    Output shape: for each dimension of ``cond``, the largest of the
    ``cond``/``a``/``b`` sizes where ``cond``'s size is 1 and no ERROR_IF
    case is requested, otherwise ``cond``'s size.

    Args:
        ser: serializer; the output is created with ``ser.addOutput``.
        rng: random generator exposing ``integers``/``choice`` (presumably a
            numpy ``Generator`` -- TODO confirm at call sites).
        cond: condition input; ``cond.shape`` is read.
        a: first value input; ``a.shape`` and ``a.dtype`` are read.
        b: second value input; must match ``a``'s dtype, and (unless
            ``error_name`` is ``ErrorIf.RankMismatch``) all three inputs
            must share one rank.
        error_name: optional ERROR_IF case. ``DimensionMismatch`` grows one
            randomly chosen output dimension by 1; ``WrongOutputType`` picks
            an output dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i, cond_dim in enumerate(cond.shape):
        if cond_dim == 1 and error_name is None:
            shape.append(max(cond_dim, a.shape[i], b.shape[i]))
        else:
            shape.append(cond_dim)

    # The fuzz index is drawn even when DimensionMismatch is not requested
    # (as the generator always has), so later random values are unaffected.
    # For DimensionMismatch the ranks are asserted equal above, so the index
    # is valid for ``shape``. A rank-0 input has no dimension to perturb and
    # rng.integers(0, 0) raises ValueError, so skip the draw there.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any dtype other than the inputs' makes the output type invalid.
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004186
4187 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of a broadcasting comparison op.

    Output shape: for each dimension of ``a``, ``b``'s size where ``a``'s
    size is 1 and ``b`` has that dimension, otherwise ``a``'s size. This
    broadcasting is applied regardless of ``error_name``. The output dtype
    is ``DType.BOOL`` unless ``WrongOutputType`` is requested.

    Args:
        ser: serializer; the output is created with ``ser.addOutput``.
        rng: random generator exposing ``integers``/``choice`` (presumably a
            numpy ``Generator`` -- TODO confirm at call sites).
        a: first input; ``a.shape`` and ``a.dtype`` are read.
        b: second input; must have ``a``'s rank (unless ``error_name`` is
            ``ErrorIf.RankMismatch``) and ``a``'s dtype.
        error_name: optional ERROR_IF case. ``DimensionMismatch`` grows one
            randomly chosen output dimension by 1; ``WrongOutputType`` picks
            a non-BOOL output dtype.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i, a_dim in enumerate(a.shape):
        if a_dim == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a_dim)

    # The fuzz index is drawn even when DimensionMismatch is not requested
    # (as the generator always has), so later random values are unaffected.
    # A rank-0 input has no dimension to perturb and rng.integers(0, 0)
    # raises ValueError, so skip the draw there.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # The valid output dtype is BOOL, so every entry here is wrong.
        # The list order feeds rng.choice directly; keep it stable.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004220
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction over ``axis`` to ``ser``.

    The reduced axis becomes size 1, except in the axis-related ERROR_IF
    cases, which keep the input shape. For ShapeOfAxisNotOne, a size-1 axis
    is replaced by a random size drawn from [2, 10).
    """
    out_shape = a.shape.copy()
    axis_error_cases = [
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    ]
    if error_name not in axis_error_cases:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = set(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
        out_dtype = rng.choice(list(candidate_dtypes - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004249
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of ARGMAX over ``axis`` to ``ser``.

    The reduced axis is removed from the shape unless the axis itself is
    the injected error. The output-rank and output-shape ERROR_IF cases
    then perturb the shape. The output dtype is INT32 unless a wrong one is
    requested.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop the leading dim (when more than one remains) or
        # append a trailing size-1 dim.
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        every_dtype = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(every_dtype) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004283
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of CONV2D to ``ser``.

    IFM is NHWC, the filter is OHWI and the OFM is NHWC. Each spatial output
    size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.
    """

    def spatial_size(axis):
        # axis 0 is height, 1 is width; IFM and filter both hold their
        # spatial dims at positions 1 and 2.
        return (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1

    spatial = [spatial_size(0), spatial_size(1)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 grows height, 2 grows width, 3 grows both
        which = rng.choice(steps)
        for axis in range(2):
            if which in (axis + 1, 3):
                # grow in multiples of the stride so the result stays integral
                spatial[axis] += rng.choice(steps) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004335
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of CONV3D to ``ser``.

    IFM is NDHWC, the filter is ODHWI and the OFM is NDHWC. Each spatial
    output size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.
    """

    def spatial_size(axis):
        # axis 0/1/2 is depth/height/width; IFM and filter both hold their
        # spatial dims at positions 1..3.
        return (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis + 1] - 1) * dilations[axis]
        ) // strides[axis] + 1

    spatial = [spatial_size(0), spatial_size(1), spatial_size(2)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1/2/3 grow depth/height/width alone, 4 grows all three
        which = rng.choice(steps)
        for axis in range(3):
            if which in (axis + 1, 4):
                # grow in multiples of the stride so the result stays integral
                spatial[axis] += rng.choice(steps) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
4397
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of DEPTHWISE_CONV2D to ``ser``.

    IFM is NHWC, the filter is HWCM and the OFM is NHW(C*M). Each spatial
    output size is ``(in - 1 + pad_before + pad_after - (k - 1) * dilation)
    // stride + 1``.
    """

    def spatial_size(axis):
        # axis 0 is height, 1 is width; the IFM holds them at positions 1
        # and 2, the HWCM filter at positions 0 and 1.
        return (
            ifm.shape[axis + 1]
            - 1
            + padding[2 * axis]
            + padding[2 * axis + 1]
            - (filter.shape[axis] - 1) * dilations[axis]
        ) // strides[axis] + 1

    spatial = [spatial_size(0), spatial_size(1)]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 grows height, 2 grows width, 3 grows both
        which = rng.choice(steps)
        for axis in range(2):
            if which in (axis + 1, 3):
                # grow in multiples of the stride so the result stays integral
                spatial[axis] += rng.choice(steps) * strides[axis]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004448
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2D pooling op to ``ser``.

    ``ifm`` is NHWC. The output keeps N and C, and recomputes H and W from
    the kernel, stride and padding.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Stride or padding is invalid: use 1x1 spatial dims, the test is
        # invalid anyway.
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 grows height, 2 grows width, 3 grows both
        which = rng.choice(steps)
        # grow in multiples of the stride so the result stays integral
        if which in (1, 3):
            h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004486
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor of FULLY_CONNECTED to ``ser``.

    ``input`` is [N, IC] and ``filter`` is [OC, IC], giving an [N, OC]
    output. The output dtype is always ``accum_dtype``; per the original
    note, it is validated in arg_gen (and invalidated there for ERROR_IF).
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004499
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor of MATMUL to ``ser``.

    ``a`` is [N, H, C] and ``b`` is [N, C, W], giving an [N, H, W] output.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        small_ints = (DType.INT4, DType.INT8, DType.INT16)
        floats = (DType.FP32, DType.FP16, DType.BF16)
        # Candidate lists per input dtype: INT8 omits INT32, INT16 omits
        # INT48, and float inputs only get integer candidates.
        if a.dtype == DType.INT8:
            incorrect_types = small_ints + (DType.INT48,) + floats
        elif a.dtype == DType.INT16:
            incorrect_types = small_ints + (DType.INT32,) + floats
        elif a.dtype in floats:
            incorrect_types = small_ints + (DType.INT32, DType.INT48)
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004547
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add the output tensor of CONCAT along ``axis`` to ``ser``.

    The output is the first input's shape with the ``axis`` sizes of the
    other inputs added on. For rank/axis ERROR_IF cases the sum is skipped
    and the first input's shape is used unchanged.
    """
    first = a[0]
    rest = a[1:]

    output_shape = first.shape.copy()
    cannot_sum = (
        # unable to concat tensors of different ranks
        error_name == ErrorIf.ConcatInputRankMismatch
        # unable to concat tensors along an invalid axis
        or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
    )
    if not cannot_sum:
        for other in rest:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004583
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of PAD to ``ser``.

    Each output dim is ``padding[dim][0] + padding[dim][1] + a.shape[dim]``.
    It is then optionally shrunk, or re-ranked, for the ERROR_IF cases.
    """
    output_shape = a.shape.copy()
    for dim, size in enumerate(a.shape):
        output_shape[dim] = padding[dim][0] + padding[dim][1] + size

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Shrink one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] -= rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004614
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of RESHAPE to ``ser``.

    The output shape is a copy of ``shape``. For the size-mismatch
    ERROR_IF case, every dim is grown by a random amount in [1, 10).
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim, size in enumerate(shape):
            output_shape[dim] = size + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004639
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor of SLICE to ``ser``.

    The output shape is a copy of ``size`` unless a shape-related ERROR_IF
    case perturbs or replaces it. ``start`` is not used here.
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # dims <= 2 are only ever increased (presumably to keep them
            # positive); larger dims may move either way.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004673
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor of TILE to ``ser``.

    Each dim of ``a`` is multiplied by the matching entry of ``multiples``.
    The result is then optionally re-ranked for the RankMismatch ERROR_IF
    case.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim, reps in enumerate(multiples):
        output_shape[dim] = a.shape[dim] * reps

    if error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004702
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor of TRANSPOSE to ``ser``.

    Output dim ``i`` is ``a.shape[perms[i]]``. For the invalid-permutation
    ERROR_IF cases the permutation is not applied and the input shape is
    kept.
    """
    output_shape = a.shape.copy()

    assert len(perms) == len(output_shape)

    if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        for dst, src in enumerate(perms):
            output_shape[dst] = a.shape[src]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim in range(len(output_shape)):
            output_shape[dim] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004735
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor of GATHER to ``ser``.

    The output shape is [values.shape[0], indices.shape[1], values.shape[2]],
    with the dtype of ``values`` unless a wrong output dtype is requested.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    gathered = indices.shape[1]
    channels = values.shape[2]
    output_shape = [batch, gathered, channels]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08004761
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor of SCATTER to ``ser``.

    The output reuses ``values_in.shape`` (the same object, not a copy)
    and, unless a wrong output dtype is requested, ``values_in.dtype``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    out_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values_in.dtype}))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004790
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor of TABLE to ``ser``.

    The shape matches ``input``. The dtype is INT32 for INT16 input and
    INT8 otherwise, unless a wrong output dtype is requested.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(
            [
                dtype
                for dtype in [
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                ]
                if dtype != output_dtype
            ]
        )

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004810
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor of RESIZE to ``serializer``.

    ``scale`` is read as [y_n, y_d, x_n, x_d] and ``offset``/``border`` as
    [y, x]. The output is [N, OH, OW, C] built from ``input``'s dims 0 and
    3, with ERROR_IF cases perturbing individual entries. ``mode`` and
    ``input_dtype`` are not used here.
    """
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the size calculation below stays well defined.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(factor, 1) for factor in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            oh = min(oh, MAX_RESIZE_DIMENSION - 1)
            ow = min(ow, MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def shifted(size, step):
            # Step up by one denominator, or down when stepping up would
            # reach MAX_RESIZE_DIMENSION.
            if size + step < MAX_RESIZE_DIMENSION:
                return size + step
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size

        # 1 changes OH, 2 changes OW, 3 changes both; steps are multiples
        # of scale_y/x_d so we don't hit the non-integer error case.
        change = rng.choice([1, 2, 3])
        if change in (1, 3):
            oh = shifted(oh, scale_y_d)
        if change in (2, 3):
            ow = shifted(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # input.shape[3] is not read here; the batch size fills the last dim.
        channels = input.shape[0]
    elif error_name == ErrorIf.BatchMismatch:
        batch = batch + rng.integers(1, 10)
        channels = input.shape[3]
    elif error_name == ErrorIf.ChannelMismatch:
        channels = input.shape[3] + rng.integers(1, 10)
    else:
        channels = input.shape[3]
    output_dims = [batch, oh, ow, channels]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004889
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor for a type-conversion op to the serializer.

    The output keeps the shape of ``val`` and takes ``out_dtype`` as its
    element type. ``rng`` and ``error_name`` are accepted but not used.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004893
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor for TRANSPOSE_CONV2D to the serializer.

    Args:
        ser: serializer providing ``addOutput(shape, dtype)``.
        rng: random generator; only consulted for the error cases below.
        ifm: input tensor; only its ``dtype`` is read.
        output_shape: requested output shape. NOTE: for
            ``ErrorIf.ConvOutputShapeMismatch`` this list is modified IN
            PLACE (index 1 and/or 2 is enlarged), so the caller also sees
            the corrupted shape.
        accum_dtype: used as the output dtype unless an error case
            overrides it.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        options = [1, 2, 3]
        which = rng.choice(options)
        # 1 -> grow index 1, 2 -> grow index 2, 3 -> grow both; each
        # growth amount is itself drawn from the same options.
        grow_h = which in [1, 3]
        grow_w = which in [2, 3]
        if grow_h:
            output_shape[1] += rng.choice(options)
        if grow_w:
            output_shape[2] += rng.choice(options)

    # When the input type is deliberately wrong, fall back to some
    # potentially correct output dtype instead of the accumulator type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32 from the "wrong" pool,
        # presumably because either is a valid output for FP16 — confirm
        # against the TOSA spec.
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004919
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for FFT2D to the serializer.

    Both outputs share a single shape and dtype. These are copied from
    ``ifm1`` and then optionally corrupted to provoke ``error_name``.

    Args:
        serializer: object providing ``addOutput(shape, dtype)``.
        rng: random generator; only consulted for the error cases below.
        ifm1, ifm2: the two input tensors. They must share a dtype, and
            must share a shape unless ``ErrorIf.FFTInputShapeMismatch``.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        A list holding the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        # Pick any usable dtype other than FP32.
        out_dtype = rng.choice(list(usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Enlarge dimension 1 or 2 by a random amount in [1, 9].
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
4950
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for RFFT2D to the serializer.

    The last dimension ``W`` of ``value`` becomes ``W // 2 + 1`` in the
    outputs, and all leading dimensions are kept. Both outputs share that
    shape and the input's dtype. Either may then be corrupted to provoke
    ``error_name``.

    Args:
        serializer: object providing ``addOutput(shape, dtype)``.
        rng: random generator; only consulted for the error cases below.
        value: the input tensor; its ``shape`` and ``dtype`` are read.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        A list holding the two tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        # Pick any usable dtype other than FP32.
        out_dtype = rng.choice(list(usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Enlarge dimension 1 or 2 by a random amount in [1, 9].
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]