blob: bd371ebb80d01a70095f72b8623de02ea469f7ab [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010027 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000028 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010029 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010030 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010031 TOSA_8K_LEVEL_MAX_KERNEL = 8192
32 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033
def __init__(self, args):
    """Initialise generator state from the parsed command-line ``args``.

    Records the output directory and seed, creates a seeded numpy Generator,
    builds the operator lists and stores the floating-point bounds used for
    random float data.
    """
    out_dir, seed = args.output_dir, args.random_seed
    self.args = args
    self.basePath = out_dir
    self.random_seed = seed
    # No serializer until createSerializer() is called for a test
    self.ser = None
    self.rng = np.random.default_rng(seed)
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Bounds for the rng.uniform draws of floating point values
    fp_range = args.tensor_fp_value_range
    self.random_fp_low = min(fp_range)
    self.random_fp_high = max(fp_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070048
def createSerializer(self, opName, testPath):
    """Start a new serializer writing into ``<basePath>/<opName>/<testPath>``.

    The directory is created if missing, and ``self.testPath`` is set to the
    op-relative path that serialize() later writes into.
    """
    self.testPath = os.path.join(opName, testPath)
    out_dir = os.path.join(self.basePath, self.testPath)
    os.makedirs(out_dir, exist_ok=True)
    dump_consts = self.args.dump_consts
    self.ser = ts.TosaSerializer(out_dir, saveConstsToFile=dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070055
def getSerializer(self):
    """Return the current test's serializer (None before createSerializer)."""
    serializer = self.ser
    return serializer
58
def serialize(self, testName):
    """Write the current test's flatbuffer and JSON descriptor to disk.

    Produces ``<testName>.tosa`` (binary output of ``self.ser.serialize()``)
    and ``desc.json`` (text output of ``self.ser.writeJson``) inside
    ``<basePath>/<testPath>``.

    Args:
        testName: base name for the ``.tosa`` file; the resulting file name
            is also passed to ``self.ser.writeJson``.
    """
    out_dir = os.path.join(self.basePath, self.testPath)
    tosa_name = "{}.tosa".format(testName)

    with open(os.path.join(out_dir, tosa_name), "wb") as fd:
        fd.write(self.ser.serialize())

    # Explicit encoding so the descriptor bytes don't depend on host locale
    with open(os.path.join(out_dir, "desc.json"), "w", encoding="utf-8") as fd:
        fd.write(self.ser.writeJson(tosa_name))
Eric Kunzee5e26762020-10-13 16:11:07 -070067
def resetRNG(self, seed=None):
    """Replace ``self.rng`` with a freshly seeded numpy Generator.

    Args:
        seed: seed for the new generator; when None, ``self.random_seed + 1``
            is used instead.
    """
    new_seed = self.random_seed + 1 if seed is None else seed
    self.rng = np.random.default_rng(new_seed)
72
def getRandTensor(self, shape, dtype):
    """Return a numpy array of random data with ``shape`` for a TOSA ``dtype``.

    BOOL draws from {False, True}; integer types draw from their full range
    (INT4 is limited to -7..7), stored as int32 (int64 for INT48); floating
    point types draw uniformly between ``self.random_fp_low`` and
    ``self.random_fp_high``.

    Raises:
        Exception: if ``dtype`` is not handled.
    """
    # dtype -> (low inclusive, high exclusive, numpy container type)
    int_ranges = {
        # TOSA specific INT4 weight range from -7 to 7
        DType.INT4: (-7, 8, np.int32),
        DType.INT8: (-128, 128, np.int32),
        DType.UINT8: (0, 256, np.int32),
        DType.INT16: (-32768, 32768, np.int32),
        DType.UINT16: (0, 65536, np.int32),
        DType.INT32: (-(1 << 31), (1 << 31), np.int32),
        DType.INT48: (-(1 << 47), (1 << 47), np.int64),
    }

    if dtype == DType.BOOL:
        return np.bool_(self.rng.choice(a=[False, True], size=shape))

    if dtype in int_ranges:
        low, high, container = int_ranges[dtype]
        return container(self.rng.integers(low=low, high=high, size=shape))

    if dtype in (DType.FP16, DType.BF16, DType.FP32):
        values = self.rng.uniform(
            low=self.random_fp_low, high=self.random_fp_high, size=shape
        )
        if dtype == DType.FP16:
            return np.float16(values)
        f32_values = np.float32(values)
        if dtype == DType.BF16:
            # Floor the last 16 bits of each f32 value
            return np.float32(vect_f32_to_bf16(f32_values))
        return f32_values

    raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700117
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair, filled with random data.

    Returns:
        The values returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    # Random data is drawn immediately before each placeholder is added
    return [
        self.ser.addPlaceholder(shape, dtype, self.getRandTensor(shape, dtype))
        for shape, dtype in zip(shape_list, dtype_list)
    ]
128
def buildConstTensors(self, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair, filled with random data.

    Returns:
        The values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    added = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        added.append(self.ser.addConst(shape, dtype, values))
    return added
139
def makeShape(self, rank):
    """Return an int32 shape array.

    If ``self.targetted_shape`` is truthy, it is returned (as int32).
    Otherwise ``rank`` dimensions are drawn from
    ``[tensor_shape_range[0], tensor_shape_range[1])``.
    """
    dims = self.targetted_shape
    if not dims:
        shape_range = self.args.tensor_shape_range
        dims = self.rng.integers(
            low=shape_range[0], high=shape_range[1], size=rank
        )
    return np.int32(dims)
Eric Kunzee5e26762020-10-13 16:11:07 -0700150
def setTargetShape(self, shape):
    """Make makeShape() return ``shape`` instead of a random shape.

    Passing None (or any falsy shape) restores random shapes, since
    makeShape only uses the target when it is truthy.
    """
    self.targetted_shape = shape
153
def randInt(self, low=0, high=256):
    """Draw a single np.int32 from ``[low, high)`` using ``self.rng``."""
    (value,) = np.int32(self.rng.integers(low=low, high=high, size=1))
    return value
156
def getRandNumberDType(self, dtype):
    """Return one random value for a TOSA ``dtype``.

    Floating point types are drawn with ``self.rng.uniform`` between
    ``self.random_fp_low`` and ``self.random_fp_high``; BOOL is a random
    choice of False/True; integer types use ``self.rng.integers`` over the
    type's range (INT4 is limited to -7..7). UINT8 and UINT16 use the same
    ranges as getRandTensor. INT48 values are returned as np.int64; other
    integer types as np.int32.

    Raises:
        Exception: if ``dtype`` is not handled.
    """
    if dtype == DType.FP32:
        return np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.FP16:
        return np.float16(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
    elif dtype == DType.BF16:
        rand_f32 = np.float32(
            self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
        )
        return vect_f32_to_bf16(rand_f32)
    elif dtype == DType.BOOL:
        return self.rng.choice([False, True])
    # TOSA specific INT4 weight range from -7 to 7
    elif dtype == DType.INT4:
        low, high = (-7, 8)
    elif dtype == DType.INT8:
        low, high = (-128, 128)
    elif dtype == DType.UINT8:
        # Previously unsupported here although getRandTensor handles it
        low, high = (0, 256)
    elif dtype == DType.INT16:
        low, high = (-32768, 32768)
    elif dtype == DType.UINT16:
        # Previously unsupported here although getRandTensor handles it
        low, high = (0, 65536)
    elif dtype == DType.INT32:
        low, high = (-(1 << 31), (1 << 31))
    elif dtype == DType.INT48:
        low, high = (-(1 << 47), (1 << 47))
        # Special size: the range exceeds int32, so return an int64
        return np.int64(self.rng.integers(low, high, size=1))[0]
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(low, high, size=1))[0]
190
def shapeStr(self, shape):
    """Format ``shape`` as its dimensions joined by "x" (e.g. ``2x3x4``)."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700199
def typeStr(self, dtype):
    """Return the short name of a dtype, or of a list/tuple of dtypes.

    For a list/tuple (at least two entries), every entry is converted, so an
    unknown entry still raises. Only the first two names are joined with "x",
    because a third entry is the accumulator type.

    Raises:
        Exception: if a single ``dtype`` has no entry in DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        if dtype not in DTYPE_ATTRIBUTES:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return DTYPE_ATTRIBUTES[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    # Limit types to the first 2 as the 3rd is the accumulator
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700213
def typeWidth(self, dtype):
    """Return the "width" entry recorded for ``dtype`` in DTYPE_ATTRIBUTES.

    Raises:
        Exception: if ``dtype`` has no entry.
    """
    if dtype not in DTYPE_ATTRIBUTES:
        raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
    return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700220
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to ``args.max_batch_size``, modifying ``shape`` in place.

    Nothing changes when no max batch size is set or when explicit target
    shapes were requested. The same ``shape`` object is always returned.
    """
    limit = self.args.max_batch_size
    if not limit or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], limit)
    return shape
226
def makeDimension(self):
    """Return one random dimension size via randInt over ``args.tensor_shape_range``."""
    shape_range = self.args.tensor_shape_range
    return self.randInt(low=shape_range[0], high=shape_range[1])
231
# Operator build functions
# The argument generators return a list of tuples
# (stringDescriptor, [build_fcn_arg_list]): the string descriptor is used to
# generate the test name, and the build_fcn_arg_list is expanded and passed to
# one of the build_* functions that follow. These functions typically add
# their operator to the serializer and return the result tensor; those that
# run ERROR_IF validation return None when that validation fails.
237
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
    """Add a single-input operator on ``a`` and return its result tensor.

    Args:
        op: a bare operator value (int) or an op-definition dict with "op"
            and "operands" entries.
        a: input tensor.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: [input zero point, output zero point]; only used to build the
            NEGATE attribute. A missing qinfo is treated as [0, 0].

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    # op may be a bare int (noted as coming from build_placeholder); it and
    # IDENTITY are added directly, without ERROR_IF checks or attributes
    if isinstance(op, int):
        self.ser.addOperator(op, a.name, result_tens.name, None)
        return result_tens
    elif op["op"] == Op.IDENTITY:
        self.ser.addOperator(op["op"], a.name, result_tens.name, None)
        return result_tens

    # WrongOutputType may pick a non-8-bit output; regenerate zero points for
    # the actual input/output dtypes in that case
    if error_name == ErrorIf.WrongOutputType:
        if result_tens.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, a.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = None
    if op["op"] == Op.NEGATE:
        # Treat missing quantization info as zero points rather than failing
        # with a TypeError on qinfo[0]
        if qinfo is None:
            qinfo = [0, 0]
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
288
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
    """Add a broadcasting two-input operator on ``a`` and ``b``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
321
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a two-input operator whose output is shaped by binaryNonBroadcastOp.

    No ERROR_IF validation happens here; ``validator_fcns`` and
    ``error_name`` are accepted but not used.
    """
    out_tensor = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    input_names = [a.name, b.name]
    self.ser.addOperator(op["op"], input_names, [out_tensor.name])
    return out_tensor
326
def build_arithmetic_right_shift(
    self, op, a, b, round, validator_fcns=None, error_name=None
):
    """Add an arithmetic right shift of ``a`` by ``b``.

    ``round`` is stored in the ArithmeticRightShiftAttribute.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)
    self.ser.addOperator(op["op"], in_names, out_names, shift_attr)
    return result_tens
364
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
    """Add a multiply of ``a`` and ``b`` with a MulAttribute holding ``shift``.

    Non-float inputs force an INT32 result. For WrongOutputType tests, the
    result dtype is then replaced by a random choice of INT8/INT16/INT48.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryBroadcastOp(
        self.ser, self.rng, a, b, error_name
    )

    # Special for multiply: integer inputs always produce INT32
    is_float_input = a.dtype in (DType.FP16, DType.BF16, DType.FP32)
    if not is_float_input:
        result_tens.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Deliberately choose an output type other than INT32
        wrong_dtype = self.rng.choice([DType.INT8, DType.INT16, DType.INT48])
        result_tens.setDtype(wrong_dtype)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)
    self.ser.addOperator(op["op"], in_names, out_names, mul_attr)
    return result_tens
409
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
    """Add a table lookup on ``a`` whose TableAttribute holds ``table``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

    table_attr = ts.TosaSerializerAttribute()
    table_attr.TableAttribute(table)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, table_attr)
    return result_tens
443
def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
    """Add a select of ``a``/``b`` driven by ``cond``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [cond.name, a.name, b.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
480
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
    """Add a comparison between ``a`` and ``b``.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.binaryComparisonOp(
        self.ser, self.rng, a, b, error_name
    )

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return result_tens
519
def build_argmax(self, op, a, axis, validator_fcns, error_name):
    """Add an argmax of ``a`` along ``axis`` (stored in an AxisAttribute).

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

    n_placeholders, n_consts = op["operands"]
    # For ERROR_IF tests the helper may invalidate the name lists
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result_tens.name]
    )

    checks_pass = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=n_placeholders + n_consts,
    )
    if not checks_pass:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return result_tens
554
def build_pool2d(
    self,
    op,
    input,
    accum_dtype,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator on ``input`` and return its result tensor.

    Args:
        op: op-definition dict with "op" and "operands" entries.
        input: input tensor.
        accum_dtype: accumulator dtype stored in the PoolAttribute
            (build_maxpool2d passes DType.UNKNOWN).
        stride, pad, kernel: pooling parameters, passed to
            OutputShaper.pool2dOp, the ERROR_IF validators and PoolAttribute.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: two zero points stored in the PoolAttribute; defaults to
            [0, 0] when None.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.pool2dOp(
        self.ser, self.rng, input, kernel, stride, pad, error_name
    )

    # For WrongInputType tests with a non-8-bit input, regenerate the zero
    # points from the actual input and output dtypes
    if error_name == ErrorIf.WrongInputType:
        if input.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(self, input.dtype),
                TosaQuantGen.getZeroPoint(self, result_tens.dtype),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [input.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input.shape,
        input_dtype=input.dtype,
        output_shape=result_tens.shape,
        output_dtype=result_tens.dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    # Validators above see the caller's qinfo (possibly None); only the
    # serialized attribute needs concrete zero points
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
616
def build_maxpool2d(
    self,
    op,
    input,
    stride,
    pad,
    kernel,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a max-pool operator by delegating to build_pool2d.

    Max pooling has no accumulator, so DType.UNKNOWN is supplied as
    ``accum_dtype``; every other argument is forwarded unchanged.
    """
    no_accumulator = DType.UNKNOWN
    return self.build_pool2d(
        op,
        input,
        no_accumulator,
        stride,
        pad,
        kernel,
        validator_fcns=validator_fcns,
        error_name=error_name,
        qinfo=qinfo,
    )
641
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D convolution and return its result tensor.

    Args:
        op: op-definition dict with "op" and "operands" entries.
        ifm: input feature map tensor.
        filter: weight tensor.
        bias: bias tensor.
        accum_dtype: accumulator dtype, passed to OutputShaper.conv2dOp.
        strides, padding, dilations: convolution parameters; ``padding``
            must have exactly 4 entries.
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ErrorIf case being generated, or None for a valid test.
        qinfo: two zero points stored in the ConvAttribute; defaults to
            [0, 0] when None.

    Returns:
        The result tensor, or None if ERROR_IF validation rejects the test.
    """
    assert len(padding) == 4
    result_tens = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # For WrongInputType tests with a non-8-bit input, regenerate the zero
    # points. NOTE(review): the second zero point is derived from the output
    # dtype; if ConvAttribute expects a weight zero point here, confirm
    # whether filter.dtype was intended.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    # Like build_pool2d: validators see the caller's qinfo, but the
    # attribute needs concrete zero points instead of failing on qinfo[0]
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
713
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000714 def build_conv3d(
715 self,
716 op,
717 ifm,
718 filter,
719 bias,
James Ward8b390432022-08-12 20:48:56 +0100720 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000721 strides,
722 padding,
723 dilations,
724 validator_fcns=None,
725 error_name=None,
726 qinfo=None,
727 ):
Kevin Cheng1533b852021-09-01 12:51:58 -0700728 assert len(padding) == 6
729 result_tens = OutputShaper.conv3dOp(
James Ward8b390432022-08-12 20:48:56 +0100730 self.ser,
731 self.rng,
732 ifm,
733 filter,
734 accum_dtype,
735 strides,
736 padding,
737 dilations,
738 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000739 )
740
741 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000742 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
743 DType.INT8,
744 DType.UINT8,
745 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000746 qinfo = [
747 TosaQuantGen.getZeroPoint(self, ifm.dtype),
748 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
749 ]
Les Bell0e027d42021-11-09 14:42:14 +0000750
751 # Invalidate Input/Output list for error_if checks.
752 input_list = [ifm.name, filter.name, bias.name]
753 output_list = [result_tens.name]
754 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000755 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
756 self, error_name, input_list, output_list
757 )
Les Bell0e027d42021-11-09 14:42:14 +0000758
Les Bell729b0352021-11-24 10:28:21 +0000759 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000760 self.ser,
761 validator_fcns,
762 error_name,
763 op=op,
764 input_dtype=ifm.dtype,
765 weight_dtype=filter.dtype,
766 output_dtype=result_tens.dtype,
767 qinfo=qinfo,
768 input_list=input_list,
769 num_operands=num_operands,
770 output_list=output_list,
771 pad=padding,
772 stride=strides,
773 dilation=dilations,
774 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100775 weight_shape=filter.shape,
776 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000777 ):
778 return None
Kevin Cheng1533b852021-09-01 12:51:58 -0700779
780 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +0000781 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Kevin Cheng1533b852021-09-01 12:51:58 -0700782
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000783 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Cheng1533b852021-09-01 12:51:58 -0700784 return result_tens
785
Kevin Cheng550ccc52021-03-03 11:21:43 -0800786 def build_transpose_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000787 self,
788 op,
789 ifm,
790 filter,
791 bias,
James Ward8b390432022-08-12 20:48:56 +0100792 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000793 stride,
TatWai Chong24594f52022-06-08 00:48:04 -0700794 out_pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000795 output_shape,
796 validator_fcns=None,
797 error_name=None,
798 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800799 ):
TatWai Chong24594f52022-06-08 00:48:04 -0700800 assert len(out_pad) == 4
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000801 result_tens = OutputShaper.transposeConv2DOp(
James Ward8b390432022-08-12 20:48:56 +0100802 self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000803 )
Les Bell0e027d42021-11-09 14:42:14 +0000804
805 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000806 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
807 DType.INT8,
808 DType.UINT8,
809 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000810 qinfo = [
811 TosaQuantGen.getZeroPoint(self, ifm.dtype),
812 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
813 ]
Les Bell0e027d42021-11-09 14:42:14 +0000814
815 # Invalidate Input/Output list for error_if checks.
816 input_list = [ifm.name, filter.name, bias.name]
817 output_list = [result_tens.name]
818 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000819 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
820 self, error_name, input_list, output_list
821 )
Les Bell0e027d42021-11-09 14:42:14 +0000822
Les Bell729b0352021-11-24 10:28:21 +0000823 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000824 self.ser,
825 validator_fcns,
826 error_name,
827 op=op,
828 input_dtype=ifm.dtype,
829 weight_dtype=filter.dtype,
830 output_dtype=result_tens.dtype,
831 qinfo=qinfo,
832 input_list=input_list,
833 num_operands=num_operands,
834 output_list=output_list,
TatWai Chong24594f52022-06-08 00:48:04 -0700835 pad=out_pad,
Les Bell0e027d42021-11-09 14:42:14 +0000836 stride=stride,
Les Bell0e027d42021-11-09 14:42:14 +0000837 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100838 weight_shape=filter.shape,
839 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000840 ):
841 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700842
843 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +0000844 attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700845
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000846 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700847 return result_tens
848
Kevin Cheng550ccc52021-03-03 11:21:43 -0800849 def build_depthwise_conv2d(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000850 self,
851 op,
852 ifm,
853 filter,
854 bias,
James Ward8b390432022-08-12 20:48:56 +0100855 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000856 strides,
857 padding,
858 dilations,
859 validator_fcns=None,
860 error_name=None,
861 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800862 ):
863 result_tens = OutputShaper.depthwiseConv2dOp(
James Ward8b390432022-08-12 20:48:56 +0100864 self.ser,
865 self.rng,
866 ifm,
867 filter,
868 accum_dtype,
869 strides,
870 padding,
871 dilations,
872 error_name,
Les Bell0e027d42021-11-09 14:42:14 +0000873 )
874
875 # Ensure new output type has correct qinfo
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000876 if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
877 DType.INT8,
878 DType.UINT8,
879 ):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000880 qinfo = [
881 TosaQuantGen.getZeroPoint(self, ifm.dtype),
882 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
883 ]
Les Bell0e027d42021-11-09 14:42:14 +0000884
885 # Invalidate Input/Output list for error_if checks.
886 input_list = [ifm.name, filter.name, bias.name]
887 output_list = [result_tens.name]
888 num_operands = sum(op["operands"])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000889 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
890 self, error_name, input_list, output_list
891 )
Les Bell0e027d42021-11-09 14:42:14 +0000892
Les Bell729b0352021-11-24 10:28:21 +0000893 if not TosaErrorValidator.evValidateErrorIfs(
Les Bell0e027d42021-11-09 14:42:14 +0000894 self.ser,
895 validator_fcns,
896 error_name,
897 op=op,
898 input_dtype=ifm.dtype,
899 weight_dtype=filter.dtype,
900 output_dtype=result_tens.dtype,
901 qinfo=qinfo,
902 input_list=input_list,
903 num_operands=num_operands,
904 output_list=output_list,
905 pad=padding,
906 stride=strides,
907 dilation=dilations,
908 input_shape=ifm.shape,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +0100909 weight_shape=filter.shape,
910 output_shape=result_tens.shape,
Les Bell729b0352021-11-24 10:28:21 +0000911 ):
912 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700913
914 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +0000915 attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
Eric Kunzee5e26762020-10-13 16:11:07 -0700916
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000917 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700918 return result_tens
919
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000920 def build_fully_connected(
James Ward8b390432022-08-12 20:48:56 +0100921 self,
922 op,
923 ifm,
924 filter,
925 bias,
926 accum_dtype,
927 validator_fcns=None,
928 error_name=None,
929 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000930 ):
931 result_tens = OutputShaper.fullyConnectedOp(
James Ward8b390432022-08-12 20:48:56 +0100932 self.ser, self.rng, ifm, filter, accum_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000933 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100934
935 # Invalidate Input/Output list for error if checks.
936 input_list = [ifm.name, filter.name, bias.name]
937 output_list = [result_tens.name]
938 pCount, cCount = op["operands"]
939 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000940 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
941 self, error_name, input_list, output_list
942 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100943
Les Bell729b0352021-11-24 10:28:21 +0000944 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100945 self.ser,
946 validator_fcns,
947 error_name,
948 op=op,
949 input_shape=ifm.shape,
950 input_dtype=ifm.dtype,
951 weight_dtype=filter.dtype,
952 output_shape=result_tens.shape,
953 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000954 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000955 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100956 input_list=input_list,
957 output_list=output_list,
958 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +0100959 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +0000960 ):
961 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700962
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000963 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +0000964 attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000965
966 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700967 return result_tens
968
James Ward8b390432022-08-12 20:48:56 +0100969 def build_matmul(
970 self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
971 ):
972 result_tens = OutputShaper.matmulOp(
973 self.ser, self.rng, a, b, accum_dtype, error_name
974 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100975
976 # Invalidate Input/Output list for error if checks.
977 input_list = [a.name, b.name]
978 output_list = [result_tens.name]
979 pCount, cCount = op["operands"]
980 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000981 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
982 self, error_name, input_list, output_list
983 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100984
Les Bell729b0352021-11-24 10:28:21 +0000985 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100986 self.ser,
987 validator_fcns,
988 error_name,
989 op=op,
990 input_shape=a.shape,
991 input_dtype=a.dtype,
992 input2_shape=b.shape,
993 input2_dtype=b.dtype,
994 output_shape=result_tens.shape,
995 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000996 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000997 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100998 input_list=input_list,
999 output_list=output_list,
1000 num_operands=num_operands,
James Ward8b390432022-08-12 20:48:56 +01001001 accum_dtype=accum_dtype,
Les Bell729b0352021-11-24 10:28:21 +00001002 ):
1003 return None
Matthew Haddonc4cf0372021-10-11 09:38:10 +01001004
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001005 attr = ts.TosaSerializerAttribute()
James Wardd34b3fc2023-01-18 14:51:25 +00001006 attr.MatMulAttribute(qinfo[0], qinfo[1])
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001007
1008 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001009 return result_tens
1010
Matthew Haddond6ce7252021-09-29 15:35:44 +01001011 def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
1012 result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
1013
1014 # Invalidate Input/Output list for error if checks.
1015 input_list = [a.name]
1016 output_list = [result_tens.name]
1017 pCount, cCount = op["operands"]
1018 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001019 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1020 self, error_name, input_list, output_list
1021 )
Matthew Haddond6ce7252021-09-29 15:35:44 +01001022
Les Bell729b0352021-11-24 10:28:21 +00001023 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddond6ce7252021-09-29 15:35:44 +01001024 self.ser,
1025 validator_fcns,
1026 error_name,
1027 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001028 axis=axis,
1029 input_shape=a.shape,
1030 output_shape=result_tens.shape,
1031 input_dtype=a.dtype,
1032 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001033 result_tensors=[result_tens],
Matthew Haddond6ce7252021-09-29 15:35:44 +01001034 input_list=input_list,
1035 output_list=output_list,
1036 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001037 ):
1038 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001039
1040 attr = ts.TosaSerializerAttribute()
1041 attr.AxisAttribute(axis)
1042
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001043 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001044 return result_tens
1045
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001046 def build_clamp(self, op, a, validator_fcns=None, error_name=None):
1047 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001048
Jeremy Johnson18e26662021-07-22 16:15:29 +01001049 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001050
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001051 if error_name == ErrorIf.MaxSmallerMin:
1052 # Make sure the numbers are different to invoke this error
1053 while v[0] == v[1]:
1054 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
1055 max_val = min(v)
1056 min_val = max(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001057 else:
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001058 max_val = max(v)
1059 min_val = min(v)
Eric Kunzee5e26762020-10-13 16:11:07 -07001060
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001061 # Invalidate Input/Output list for error if checks.
1062 input_list = [a.name]
1063 output_list = [result_tens.name]
1064 pCount, cCount = op["operands"]
1065 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001066 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1067 self, error_name, input_list, output_list
1068 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001069
Les Bell729b0352021-11-24 10:28:21 +00001070 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001071 self.ser,
1072 validator_fcns,
1073 error_name,
1074 op=op,
1075 max_val=max_val,
1076 min_val=min_val,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001077 input_shape=a.shape,
1078 output_shape=result_tens.shape,
1079 input_dtype=a.dtype,
1080 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001081 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001082 input_list=input_list,
1083 output_list=output_list,
1084 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001085 ):
1086 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001087
1088 attr = ts.TosaSerializerAttribute()
James Ward34071252022-12-07 15:48:47 +00001089 if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
1090 if a.dtype == DType.FP16:
1091 # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
1092 min_val = min_val.astype(np.float32)
1093 max_val = max_val.astype(np.float32)
1094
1095 attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001096 else:
James Ward34071252022-12-07 15:48:47 +00001097 attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001098
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001099 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001100 return result_tens
1101
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001102 def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
1103 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001104 attr = ts.TosaSerializerAttribute()
1105
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001106 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
Eric Kunzee5e26762020-10-13 16:11:07 -07001107
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001108 self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001109 return result_tens
1110
1111 # Needs an additional type/input
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001112 def build_prelu(self, op, a, validator_fcns=None, error_name=None):
1113 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001114
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001115 self.ser.addOperator(op["op"], [a.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001116 return result_tens
1117
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001118 def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
1119 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1120
1121 # Invalidate Input/Output list for error if checks.
1122 input_list = [a.name]
1123 output_list = [result_tens.name]
1124 pCount, cCount = op["operands"]
1125 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001126 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1127 self, error_name, input_list, output_list
1128 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001129
Les Bell729b0352021-11-24 10:28:21 +00001130 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001131 self.ser,
1132 validator_fcns,
1133 error_name,
1134 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001135 input_shape=a.shape,
1136 output_shape=result_tens.shape,
1137 input_dtype=a.dtype,
1138 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001139 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001140 input_list=input_list,
1141 output_list=output_list,
1142 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001143 ):
1144 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001145
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001146 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001147 return result_tens
1148
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001149 def build_tanh(self, op, a, validator_fcns=None, error_name=None):
1150 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1151
1152 # Invalidate Input/Output list for error if checks.
1153 input_list = [a.name]
1154 output_list = [result_tens.name]
1155 pCount, cCount = op["operands"]
1156 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001157 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1158 self, error_name, input_list, output_list
1159 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001160
Les Bell729b0352021-11-24 10:28:21 +00001161 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001162 self.ser,
1163 validator_fcns,
1164 error_name,
1165 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001166 input_shape=a.shape,
1167 output_shape=result_tens.shape,
1168 input_dtype=a.dtype,
1169 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001170 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001171 input_list=input_list,
1172 output_list=output_list,
1173 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001174 ):
1175 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001176
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001177 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -07001178 return result_tens
1179
Won Jeon78155c62023-06-10 00:20:04 +00001180 def build_erf(self, op, a, validator_fcns=None, error_name=None):
1181 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1182
1183 # Invalidate Input/Output list for error if checks.
1184 input_list = [a.name]
1185 output_list = [result_tens.name]
1186 pCount, cCount = op["operands"]
1187 num_operands = pCount + cCount
1188 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1189 self, error_name, input_list, output_list
1190 )
1191
1192 if not TosaErrorValidator.evValidateErrorIfs(
1193 self.ser,
1194 validator_fcns,
1195 error_name,
1196 op=op,
1197 input_shape=a.shape,
1198 output_shape=result_tens.shape,
1199 input_dtype=a.dtype,
1200 output_dtype=result_tens.dtype,
1201 result_tensors=[result_tens],
1202 input_list=input_list,
1203 output_list=output_list,
1204 num_operands=num_operands,
1205 ):
1206 return None
1207
1208 self.ser.addOperator(op["op"], input_list, output_list)
1209 return result_tens
1210
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001211 def build_concat(self, op, *a, validator_fcns=None, error_name=None):
1212 if error_name != ErrorIf.WrongInputType:
1213 assert type(a[-1]) == int
Matthew Haddon818ab902021-07-27 09:12:49 +01001214
1215 # To store variable length list of input tensors we need to store axis along with it
1216 axis = a[-1]
1217 a = a[:-1]
1218
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001219 result_tens = OutputShaper.concatOp(
1220 self.ser, self.rng, axis, *a, error_name=error_name
1221 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001222
Matthew Haddon818ab902021-07-27 09:12:49 +01001223 input_tensor_names = []
1224 for tensor in a:
1225 input_tensor_names.append(tensor.name)
1226
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001227 # Invalidate Input/Output list for error if checks.
1228 input_list = input_tensor_names
1229 output_list = [result_tens.name]
1230 pCount, cCount = op["operands"]
1231 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001232 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1233 self, error_name, input_list, output_list
1234 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001235
Les Bell729b0352021-11-24 10:28:21 +00001236 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001237 self.ser,
1238 validator_fcns,
1239 error_name,
1240 op=op,
1241 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001242 input_shape=a[0].shape,
1243 output_shape=result_tens.shape,
1244 input_dtype=a[0].dtype,
1245 output_dtype=result_tens.dtype,
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001246 inputs=a,
Luke Hutton261b7b62023-01-10 14:50:31 +00001247 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001248 input_list=input_list,
1249 output_list=output_list,
1250 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001251 ):
1252 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001253
1254 attr = ts.TosaSerializerAttribute()
1255 attr.AxisAttribute(axis)
1256
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001257 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddon848efb42021-09-09 12:30:53 +01001258 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001259
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001260 def build_pad(
1261 self,
1262 op,
1263 a,
1264 padding,
1265 pad_const_int,
1266 pad_const_float,
1267 validator_fcns=None,
1268 error_name=None,
1269 qinfo=None,
1270 ):
Matthew Haddone807aae2021-10-11 18:12:58 +01001271 result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07001272
Kevin Chengfe392ce2021-10-18 21:51:55 +00001273 attr = ts.TosaSerializerAttribute()
James Ward34071252022-12-07 15:48:47 +00001274 attr.PadAttribute(
1275 self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
1276 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001277
Matthew Haddone807aae2021-10-11 18:12:58 +01001278 # Invalidate Input/Output list for error if checks.
Kevin Chengfe392ce2021-10-18 21:51:55 +00001279 input_list = [a.name]
Matthew Haddone807aae2021-10-11 18:12:58 +01001280 output_list = [result_tens.name]
1281 pCount, cCount = op["operands"]
1282 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001283 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1284 self, error_name, input_list, output_list
1285 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001286
Les Bell729b0352021-11-24 10:28:21 +00001287 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001288 self.ser,
1289 validator_fcns,
1290 error_name,
1291 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001292 input_shape=a.shape,
1293 output_shape=result_tens.shape,
1294 input_dtype=a.dtype,
1295 output_dtype=result_tens.dtype,
Matthew Haddone807aae2021-10-11 18:12:58 +01001296 pad=padding,
1297 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +00001298 result_tensors=[result_tens],
Matthew Haddone807aae2021-10-11 18:12:58 +01001299 input_list=input_list,
1300 output_list=output_list,
1301 num_operands=num_operands,
Luke Huttona4e48ca2023-02-22 11:53:48 +00001302 input1=a,
Les Bell729b0352021-11-24 10:28:21 +00001303 ):
1304 return None
Matthew Haddone807aae2021-10-11 18:12:58 +01001305
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001306 self.ser.addOperator(op["op"], input_list, output_list, attr)
Matthew Haddone86fd342021-09-07 16:12:21 +01001307 return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001308
Matthew Haddone807aae2021-10-11 18:12:58 +01001309 def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001310 result_tens = OutputShaper.reshapeOp(
1311 self.ser, self.rng, a, newShape, error_name
1312 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001313
1314 # Invalidate Input/Output list for error if checks.
1315 input_list = [a.name]
1316 output_list = [result_tens.name]
1317 pCount, cCount = op["operands"]
1318 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001319 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1320 self, error_name, input_list, output_list
1321 )
Matthew Haddone807aae2021-10-11 18:12:58 +01001322
Les Bell729b0352021-11-24 10:28:21 +00001323 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone807aae2021-10-11 18:12:58 +01001324 self.ser,
1325 validator_fcns,
1326 error_name,
1327 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001328 input_shape=a.shape,
1329 output_shape=result_tens.shape,
1330 input_dtype=a.dtype,
1331 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001332 result_tensors=[result_tens],
Matthew Haddone807aae2021-10-11 18:12:58 +01001333 input_list=input_list,
1334 output_list=output_list,
1335 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001336 ):
1337 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001338
1339 attr = ts.TosaSerializerAttribute()
1340 attr.ReshapeAttribute(newShape)
1341
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001342 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001343 return result_tens
1344
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001345 def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
1346 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
1347
1348 # Invalidate Input/Output list for error if checks.
1349 input_list = [a.name]
1350 output_list = [result_tens.name]
1351 pCount, cCount = op["operands"]
1352 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001353 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
1354 self, error_name, input_list, output_list
1355 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001356
Les Bell729b0352021-11-24 10:28:21 +00001357 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001358 self.ser,
1359 validator_fcns,
1360 error_name,
1361 op=op,
1362 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001363 input_shape=a.shape,
1364 output_shape=result_tens.shape,
1365 input_dtype=a.dtype,
1366 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +00001367 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001368 input_list=input_list,
1369 output_list=output_list,
1370 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001371 ):
1372 return None
Eric Kunzee5e26762020-10-13 16:11:07 -07001373
1374 attr = ts.TosaSerializerAttribute()
1375 attr.AxisAttribute(axis)
1376
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001377 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001378 return result_tens
1379
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Emit a TRANSPOSE operator on ``a`` using the permutation ``perms``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        perms=perms,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)
    return out
1415
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Emit a SLICE operator on ``a`` with the given ``start`` and ``size``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        start=start,
        size=size,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not passed:
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)

    self.ser.addOperator(op["op"], in_names, out_names, slice_attr)
    return out
1454
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Emit a TILE operator on ``a`` with the given ``multiples``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=out.shape,
        input_dtype=a.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
        input1=a,
    )
    if not passed:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return out
1489
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Emit a GATHER operator reading from ``values`` with generated indices.

    The index tensor is created here as an INT32 constant of shape
    (values.shape[0], W). W is drawn from ``self.args.tensor_shape_range``,
    and every index lies in [0, values.shape[1]) so it stays in bounds.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    k_extent = values.shape[1]
    width = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values.shape[0], width])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out = OutputShaper.gatherOp(self.ser, self.rng, values, indices, error_name)

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, indices.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=out.shape,
        input_dtype=values.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out
1536
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Emit a SCATTER operator writing ``input`` into ``values_in``.

    The index tensor is created here as an INT32 constant of shape
    (values_in.shape[0], input.shape[1]), with every index in
    [0, values_in.shape[1]).

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    k_extent = values_in.shape[1]
    width = input.shape[1]
    index_data = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values_in.shape[0], width])
    )
    indices = self.ser.addConst(index_data.shape, DType.INT32, index_data)

    out = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, indices, input, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values_in.name, indices.name, input.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out.shape,
        input_dtype=values_in.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out
1581
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Emit a RESIZE operator on ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are forwarded both to the
    output-shape computation and to the serialized ResizeAttribute.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=out.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out],
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)

    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return out
1643
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Emit an IDENTITYN operator with inputs ``val`` and ``val2``.

    Two result tensors are created and both are listed as operator outputs,
    but only the first one is returned.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # ``op`` is the operator-description dict, as in every other build_*
    # helper. The serializer needs the operator stored under its "op" key,
    # not the whole dict.
    self.ser.addOperator(
        op["op"], [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1651
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Register ``val`` as an output tensor of the graph and return it unchanged."""
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001655
1656 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Emit a CAST operator converting ``val`` to ``out_dtype``.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    out = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    placeholders, consts = op["operands"]
    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [out.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=out.shape,
        input_dtype=val.dtype,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out
1689
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Emit a RESCALE operator converting ``val`` to ``out_dtype``.

    Steps:
      * Choose input/output zero points based on the dtypes. The
        zero-point ERROR_IF tests get deliberately non-zero values.
      * Draw a random scale per channel (last dimension) when
        ``per_channel``, otherwise a single scale.
      * Convert each scale to a multiplier/shift pair.
      * For non-error scale32 tests, clamp the saved placeholder data so it
        meets the apply_scale_32 input-range requirement.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Input zero point. Branches that may choose a non-zero zero point also
    # widen in_type_width by one bit for the scale calculation below.
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        # ERROR_IF test: force a non-zero input zero point.
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    # Output zero point, using the same scheme as the input side.
    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        # ERROR_IF test: force a non-zero output zero point.
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    #   scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Inclusive bounds of the allowed (zero-point-adjusted) input value
        # for this channel's shift.
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification:
        #   REQUIRES(value >= (-1 << (shift - 1)) && value < (1 << (shift - 1)))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1844
def _get_condition_tensor(self, op, cond, error_name):
    """Create and return the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool ERROR_IF test gets a non-BOOL dtype instead.
    The CondIfCondShapeNotSizeOne test gets a shape of [2] or [1, 2].
    """
    cond_type = (
        get_wrong_output_type(op, self.rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif self.rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
1861
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Emit a COND_IF whose THEN/ELSE blocks each output an INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. For
    the CondIfOutputList*GraphMismatch ERROR_IF tests, the affected block
    outputs a constant whose shape differs in every dimension.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        # Perturb every dimension while keeping it positive.
        bad_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(bad_shape):
            if dim > 3:
                bad_shape[idx] = dim + self.rng.choice([-3, -2, 2, 3])
            else:
                bad_shape[idx] = dim + self.rng.choice([1, 2, 4])
        bad_arr = np.int32(self.rng.integers(0, 256, size=bad_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    # The operator's single result takes the (correct) output shape.
    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # THEN block: a single constant output.
    self.ser.addBasicBlock(then_block)
    if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
        then_const = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
    else:
        then_const = self.ser.addConst(out_shape, DType.INT32, then_arr)
    self.ser.addOutputTensor(then_const)

    # ELSE block: a single constant output.
    self.ser.addBasicBlock(else_block)
    if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
        else_const = self.ser.addConst(bad_shape, DType.INT32, bad_arr)
    else:
        else_const = self.ser.addConst(out_shape, DType.INT32, else_arr)
    self.ser.addOutputTensor(else_const)

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if passed else None
1930
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Emit a COND_IF whose THEN/ELSE blocks apply a binary op to ``a`` and ``b``.

    Block operators by input dtype:
      * FP32/BF16/FP16/INT32: THEN uses ADD, ELSE uses SUB.
      * INT8/INT16: THEN uses LOGICAL_RIGHT_SHIFT, ELSE uses LOGICAL_LEFT_SHIFT.

    For the CondIf{Input,Output}List*GraphMismatch ERROR_IF tests, one block
    is given an input (or output) whose shape differs from ``a`` in every
    dimension.

    Returns the result tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Keep every perturbed dimension positive, using the same scheme
            # as build_cond_if_const. A plain +/-2 or +/-3 on a small
            # dimension would give a zero or negative size.
            incorrect_shape[i] += (
                self.rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else self.rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Use a distinct loop variable so the ``op`` parameter (the operator
    # description passed to the ERROR_IF validators below) is not clobbered.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
2013
def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
    """Build a WHILE_LOOP test: count `iter_val` down to zero, adding `a`
    into an accumulator on every iteration.

    COND block: cond = (iter > 0).
    BODY block: acc = a + acc; iter = iter - 1; a passes through unchanged.

    For the list/graph-mismatch ERROR_IF cases, shapes are perturbed with
    self.rng, so the order of rng draws below is significant.

    Returns the loop's accumulator output tensor, or None when the
    ERROR_IF validators reject the generated graph.
    """
    # Loop counter (named iter_tens so the builtin iter() is not shadowed)
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Give the accumulator output a shape that differs from its input
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        # Mis-shaped copies of the loop-carried tensors for the sub-graphs
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # The counter is a scalar, so add a bogus dimension instead
            incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    # The type is drawn before the shape; keep this order for rng stability
    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        choice = self.rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    return acc_out
2134
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Add an FFT2D operator that takes the (val1, val2) input pair.

    The result tensors come from OutputShaper.fft2dOp. `inverse` is stored
    in the operator's FFTAttribute.

    Returns the list of result tensors, or None when evValidateErrorIfs
    rejects the generated test.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    output_names, output_shapes, output_dtypes = [], [], []
    for res in results:
        output_names.append(res.name)
        output_shapes.append(res.shape)
        output_dtypes.append(res.dtype)

    # May hand back altered name lists
    # (presumably for the wrong-input/output-list ERROR_IF cases)
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val1.name, val2.name], output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)

    self.ser.addOperator(op["op"], input_names, output_names, fft_attr)
    return results
2176
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Add an RFFT2D operator that takes the single input tensor `val`.

    The result tensors come from OutputShaper.rfft2dOp.

    Returns the list of result tensors, or None when evValidateErrorIfs
    rejects the generated test.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    output_names, output_shapes, output_dtypes = [], [], []
    for res in results:
        output_names.append(res.name)
        output_shapes.append(res.shape)
        output_dtypes.append(res.dtype)

    # May hand back altered name lists
    # (presumably for the wrong-input/output-list ERROR_IF cases)
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], output_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)
    return results
2210
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape/rank/dtype filters for generating tests of `op`.

    Args:
        op: TOSA_OP_LIST entry. Reads its "rank" (min, max) and "types".
        shapeFilter: requested shapes. A falsy value means no shape filter.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. A list-typed op dtype matches
            when its first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator. Used only for negative tests.

    Returns:
        A dict with keys "shapeFilter", "rankFilter" and "dtypeFilter".
        Returns None for negative tests without a validator, or for an
        unrecognised testType.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]
    opRankRange = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Keep only requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            # Bound the default range by the operator's limits too, so ranks
            # the op does not allow (e.g. rank 1 for a (2, 6) op) are never
            # generated.
            lo = max(rmin, default_test_rank_range.start)
            hi = min(rmax, default_test_rank_range.stop - 1)
            cleanRankFilter = range(lo, hi + 1)
            if not cleanRankFilter:
                # Default window lies wholly outside the op's range:
                # use the op's lowest ranks instead.
                cleanRankFilter = opRankRange[: len(default_test_rank_range)]
    else:
        cleanRankFilter = opRankRange

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the ERROR_IF condition
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2291
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests for operator `opName`.

    Args:
        opName: key into self.TOSA_OP_LIST.
        shapeFilter: requested shapes. None (the default) means no shape
            filter; it is normalised to [None].
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.
        testType: "positive", or "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if shapeFilter is None:
        # Normalise here rather than using a shared mutable default
        shapeFilter = [None]

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Each argument list entry is a (str, []) pair: the
                    # string form and the build-function argument list.
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if testType == "positive":
                            if argStr:
                                testStr = "{}_{}_{}_{}".format(
                                    opName, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_{}_{}".format(
                                    opName, shapeStr, typeStr
                                )
                        elif testType == "negative":
                            if argStr:
                                testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr, argStr
                                )
                            else:
                                testStr = "{}_ERRORIF_{}_{}_{}".format(
                                    opName, error_name, shapeStr, typeStr
                                )

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2398
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Build the test `testStr` for operator `opName` and serialize it.

    Args:
        opName: key into self.TOSA_OP_LIST.
        testStr: test name; passed to createSerializer.
        dtype_or_dtypeList: a single dtype, repeated for every operand (for
            CONCAT, once per shape), or an explicit per-operand list.
        error_name: ERROR_IF error name, or None for a positive test.
        shapeList: operand shapes.
        testArgs: extra positional arguments for the op's build function.

    If the build function returns a falsy result, the test is reported as an
    invalid ERROR_IF test and is not serialized.

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
        TypeError: re-raised from the build function, after the build
            function, tensors and args are printed for debugging.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT takes a variable number of inputs: one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * (num_operands)

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")
    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
    else:
        qinfo = None

    # Build the random tensor operands and the test
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    try:
        if error_if_validators is None:
            if qinfo is not None:
                resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
            else:
                resultName = build_fcn(self, op, *tens, *testArgs)
        else:
            if qinfo is not None:
                resultName = build_fcn(
                    self,
                    op,
                    *tens,
                    *testArgs,
                    validator_fcns=error_if_validators,
                    error_name=error_name,
                    qinfo=qinfo,
                )
            else:
                resultName = build_fcn(
                    self,
                    op,
                    *tens,
                    *testArgs,
                    validator_fcns=error_if_validators,
                    error_name=error_name,
                )
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002492
def createDynamicOpLists(self):
    """Expand the "<op>_TEMPLATE" convolution entries of TOSA_OP_LIST.

    For every kernel size, a shallow copy of the matching template is
    inserted under a name such as "conv2d_3x1". The copy's "filter" is set
    to that kernel and its "template" flag to False. Every entry whose
    "template" flag is truthy is then removed.

    If "conv2d_TEMPLATE" is already gone, this does nothing.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to instantiate; 8K-level runs use the maximum kernel size
    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(prefix, kernel):
        # Specialise "<prefix>_TEMPLATE" for one kernel size, e.g. conv2d_3x1
        variant = op_list["{}_TEMPLATE".format(prefix)].copy()
        variant["filter"] = kernel
        variant["template"] = False
        op_list["{}_{}".format(prefix, "x".join(str(dim) for dim in kernel))] = variant

    for kernel in kernels_2d:
        for prefix in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(prefix, kernel)
    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Collect the template names first: keys must not be deleted from a
    # dict while iterating over it
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2548
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must have:
    - an "operands" pair,
    - a 4-item "build_fcn",
    - a "types" list,
    - an "op" field.
    Otherwise an Exception naming the op is raised. A missing "rank" is set
    to self.DEFAULT_RANK_RANGE.
    """
    for op_name, entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts
        try:
            (_placeholders, _consts) = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        # Required: (build, tensor gen, tensor values gen, arg gen) functions
        try:
            (_build, _tgen, _tvgen, _arggen) = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        if "types" not in entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op_name)
            )

        if "op" not in entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op_name)
            )

        # Optional: default rank range
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002590
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max).
#    If not specified, initOpListDefaults sets it to DEFAULT_RANK_RANGE,
#    i.e. (1, TOSA_TENSOR_MAX_RANK). When no filters are given,
#    create_filter_lists further limits generated ranks to at most 1-4.
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional quantization info generator function
# 'error_if_validators': optional tuple of ERROR_IF validator functions,
#    used to create negative tests
# 'invalid_test_validators': optional tuple of validator functions; positive
#    tests for which any of them returns True are dropped
# 'template': optional; True marks an entry that createDynamicOpLists
#    expands and then removes
# 'filter': kernel size list, set per kernel by createDynamicOpLists
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]
TYPE_FI16 = [DType.FP32, DType.INT16]

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002642
2643 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002644 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002645 "argmax": {
2646 "op": Op.ARGMAX,
2647 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00002648 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002649 "build_fcn": (
2650 build_argmax,
2651 TosaTensorGen.tgBasic,
2652 TosaTensorValuesGen.tvgDefault,
2653 TosaArgGen.agAxis,
2654 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002655 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002656 "error_if_validators": (
2657 TosaErrorValidator.evAxisSmallerZero,
2658 TosaErrorValidator.evAxisLargerRank,
2659 TosaErrorValidator.evArgmaxOutputRankMismatch,
2660 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2661 TosaErrorValidator.evWrongRank,
2662 TosaErrorValidator.evWrongInputType,
2663 TosaErrorValidator.evWrongOutputType,
2664 TosaErrorValidator.evWrongInputList,
2665 TosaErrorValidator.evWrongOutputList,
2666 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002667 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002668 "avg_pool2d": {
2669 "op": Op.AVG_POOL2D,
2670 "operands": (1, 0),
2671 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002672 "build_fcn": (
2673 build_pool2d,
2674 TosaTensorGen.tgNHWC,
2675 TosaTensorValuesGen.tvgDefault,
2676 TosaArgGen.agPooling,
2677 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002678 "qgen": TosaQuantGen.qgUnary,
2679 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002680 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002681 "error_if_validators": (
2682 TosaErrorValidator.evKernelSmallerOne,
2683 TosaErrorValidator.evStrideSmallerOne,
2684 TosaErrorValidator.evPadSmallerZero,
2685 TosaErrorValidator.evWrongRank,
2686 TosaErrorValidator.evWrongInputType,
2687 TosaErrorValidator.evWrongOutputType,
2688 TosaErrorValidator.evWrongInputList,
2689 TosaErrorValidator.evWrongOutputList,
2690 TosaErrorValidator.evInputZeroPointNotZero,
2691 TosaErrorValidator.evOutputZeroPointNotZero,
2692 TosaErrorValidator.evPadLargerEqualKernel,
2693 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002694 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002695 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002696 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002697 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002698 "conv2d_TEMPLATE": {
2699 "op": Op.CONV2D,
2700 "operands": (1, 2),
2701 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002702 "build_fcn": (
2703 build_conv2d,
2704 TosaTensorGen.tgConv2D,
2705 TosaTensorValuesGen.tvgDefault,
2706 TosaArgGen.agConv,
2707 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002708 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002709 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002710 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2711 "error_if_validators": (
2712 TosaErrorValidator.evWrongInputType,
2713 TosaErrorValidator.evWrongOutputType,
2714 TosaErrorValidator.evWrongInputList,
2715 TosaErrorValidator.evWrongOutputList,
2716 TosaErrorValidator.evInputZeroPointNotZero,
2717 TosaErrorValidator.evWeightZeroPointNotZero,
2718 TosaErrorValidator.evPadSmallerZero,
2719 TosaErrorValidator.evStrideSmallerOne,
2720 TosaErrorValidator.evDilationSmallerOne,
2721 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002722 TosaErrorValidator.evConvOutputShapeMismatch,
2723 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002724 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002725 "template": True,
2726 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002727 # Templated operator. Filled in by createDynamicOpLists
2728 "conv3d_TEMPLATE": {
2729 "op": Op.CONV3D,
2730 "operands": (1, 2),
2731 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002732 "build_fcn": (
2733 build_conv3d,
2734 TosaTensorGen.tgConv3D,
2735 TosaTensorValuesGen.tvgDefault,
2736 TosaArgGen.agConv,
2737 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002738 "qgen": TosaQuantGen.qgConv,
2739 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002740 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2741 "error_if_validators": (
2742 TosaErrorValidator.evWrongInputType,
2743 TosaErrorValidator.evWrongOutputType,
2744 TosaErrorValidator.evWrongInputList,
2745 TosaErrorValidator.evWrongOutputList,
2746 TosaErrorValidator.evInputZeroPointNotZero,
2747 TosaErrorValidator.evWeightZeroPointNotZero,
2748 TosaErrorValidator.evPadSmallerZero,
2749 TosaErrorValidator.evStrideSmallerOne,
2750 TosaErrorValidator.evDilationSmallerOne,
2751 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002752 TosaErrorValidator.evConvOutputShapeMismatch,
2753 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002754 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002755 "template": True,
2756 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002757 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002758 "depthwise_conv2d_TEMPLATE": {
2759 "op": Op.DEPTHWISE_CONV2D,
2760 "operands": (1, 2),
2761 "filter": [1, 1],
2762 "rank": (4, 4),
2763 "build_fcn": (
2764 build_depthwise_conv2d,
2765 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002767 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002768 ),
2769 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002770 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002771 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2772 "error_if_validators": (
2773 TosaErrorValidator.evWrongInputType,
2774 TosaErrorValidator.evWrongOutputType,
2775 TosaErrorValidator.evWrongInputList,
2776 TosaErrorValidator.evWrongOutputList,
2777 TosaErrorValidator.evInputZeroPointNotZero,
2778 TosaErrorValidator.evWeightZeroPointNotZero,
2779 TosaErrorValidator.evPadSmallerZero,
2780 TosaErrorValidator.evStrideSmallerOne,
2781 TosaErrorValidator.evDilationSmallerOne,
2782 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002783 TosaErrorValidator.evConvOutputShapeMismatch,
2784 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002785 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002786 "template": True,
2787 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002788 "fully_connected": {
2789 "op": Op.FULLY_CONNECTED,
2790 "operands": (1, 2),
2791 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002792 "build_fcn": (
2793 build_fully_connected,
2794 TosaTensorGen.tgFullyConnected,
2795 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002796 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002797 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002798 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002799 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002800 "error_if_validators": (
2801 TosaErrorValidator.evInputZeroPointNotZero,
2802 TosaErrorValidator.evWeightZeroPointNotZero,
2803 TosaErrorValidator.evWrongRank,
2804 TosaErrorValidator.evWrongInputType,
2805 TosaErrorValidator.evWrongOutputType,
2806 TosaErrorValidator.evWrongInputList,
2807 TosaErrorValidator.evWrongOutputList,
2808 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002809 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002810 "matmul": {
2811 "op": Op.MATMUL,
2812 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002813 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002814 "build_fcn": (
2815 build_matmul,
2816 TosaTensorGen.tgMatmul,
2817 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002818 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002819 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002820 "qgen": TosaQuantGen.qgMatmul,
2821 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002822 "error_if_validators": (
2823 TosaErrorValidator.evInputZeroPointNotZero,
2824 TosaErrorValidator.evWrongRank,
2825 TosaErrorValidator.evWrongInputType,
2826 TosaErrorValidator.evWrongOutputType,
2827 TosaErrorValidator.evWrongInputList,
2828 TosaErrorValidator.evWrongOutputList,
2829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002830 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002831 "max_pool2d": {
2832 "op": Op.MAX_POOL2D,
2833 "operands": (1, 0),
2834 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002835 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002836 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002837 TosaTensorGen.tgNHWC,
2838 TosaTensorValuesGen.tvgDefault,
2839 TosaArgGen.agPooling,
2840 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002841 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002842 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002843 "error_if_validators": (
2844 TosaErrorValidator.evKernelSmallerOne,
2845 TosaErrorValidator.evStrideSmallerOne,
2846 TosaErrorValidator.evPadSmallerZero,
2847 TosaErrorValidator.evWrongRank,
2848 TosaErrorValidator.evWrongInputType,
2849 TosaErrorValidator.evWrongOutputType,
2850 TosaErrorValidator.evWrongInputList,
2851 TosaErrorValidator.evWrongOutputList,
2852 TosaErrorValidator.evPadLargerEqualKernel,
2853 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002854 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002855 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002857 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002858 "transpose_conv2d_TEMPLATE": {
2859 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002860 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002861 "rank": (4, 4),
2862 "build_fcn": (
2863 build_transpose_conv2d,
2864 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002865 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002866 TosaArgGen.agTransposeConv2D,
2867 ),
2868 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002869 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002870 "invalid_test_validators": (
2871 TosaInvalidValidator.ivHeightWidthInvalid,
2872 TosaInvalidValidator.ivNonPositiveOutputShape,
2873 ),
2874 "error_if_validators": (
2875 TosaErrorValidator.evWrongInputType,
2876 TosaErrorValidator.evWrongOutputType,
2877 TosaErrorValidator.evWrongInputList,
2878 TosaErrorValidator.evWrongOutputList,
2879 TosaErrorValidator.evInputZeroPointNotZero,
2880 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002881 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002882 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002883 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002884 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002885 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002886 "template": True,
2887 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002888 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002889 "clamp": {
2890 "op": Op.CLAMP,
2891 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002892 "build_fcn": (
2893 build_clamp,
2894 TosaTensorGen.tgBasic,
2895 TosaTensorValuesGen.tvgDefault,
2896 None,
2897 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002898 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002899 "error_if_validators": (
2900 TosaErrorValidator.evMaxSmallerMin,
2901 TosaErrorValidator.evWrongInputType,
2902 TosaErrorValidator.evWrongOutputType,
2903 TosaErrorValidator.evWrongInputList,
2904 TosaErrorValidator.evWrongOutputList,
2905 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002906 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002907 "sigmoid": {
2908 "op": Op.SIGMOID,
2909 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002910 "build_fcn": (
2911 build_sigmoid,
2912 TosaTensorGen.tgBasic,
2913 TosaTensorValuesGen.tvgDefault,
2914 None,
2915 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002916 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002917 "error_if_validators": (
2918 TosaErrorValidator.evWrongInputType,
2919 TosaErrorValidator.evWrongOutputType,
2920 TosaErrorValidator.evWrongInputList,
2921 TosaErrorValidator.evWrongOutputList,
2922 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002923 },
2924 "tanh": {
2925 "op": Op.TANH,
2926 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002927 "build_fcn": (
2928 build_tanh,
2929 TosaTensorGen.tgBasic,
2930 TosaTensorValuesGen.tvgDefault,
2931 None,
2932 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002933 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002934 "error_if_validators": (
2935 TosaErrorValidator.evWrongInputType,
2936 TosaErrorValidator.evWrongOutputType,
2937 TosaErrorValidator.evWrongInputList,
2938 TosaErrorValidator.evWrongOutputList,
2939 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002940 },
Won Jeon78155c62023-06-10 00:20:04 +00002941 "erf": {
2942 "op": Op.ERF,
2943 "operands": (1, 0),
2944 "build_fcn": (
2945 build_erf,
2946 TosaTensorGen.tgBasic,
2947 TosaTensorValuesGen.tvgDefault,
2948 None,
2949 ),
2950 "types": TYPE_FP,
2951 "error_if_validators": (
2952 TosaErrorValidator.evWrongInputType,
2953 TosaErrorValidator.evWrongOutputType,
2954 TosaErrorValidator.evWrongInputList,
2955 TosaErrorValidator.evWrongOutputList,
2956 ),
2957 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002958 # Elementwise Binary Operators
2959 "add": {
2960 "op": Op.ADD,
2961 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002962 "build_fcn": (
2963 build_binary_broadcast,
2964 TosaTensorGen.tgBroadcastFuzz,
2965 TosaTensorValuesGen.tvgAddSub,
2966 None,
2967 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002968 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002969 "error_if_validators": (
2970 TosaErrorValidator.evRankMismatch,
2971 TosaErrorValidator.evWrongInputType,
2972 TosaErrorValidator.evWrongOutputType,
2973 TosaErrorValidator.evWrongInputList,
2974 TosaErrorValidator.evWrongOutputList,
2975 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002976 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002977 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002978 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002979 "arithmetic_right_shift": {
2980 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2981 "operands": (2, 0),
2982 "build_fcn": (
2983 build_arithmetic_right_shift,
2984 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002985 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002986 TosaArgGen.agArithmeticRightShift,
2987 ),
2988 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002989 "error_if_validators": (
2990 TosaErrorValidator.evRankMismatch,
2991 TosaErrorValidator.evWrongInputType,
2992 TosaErrorValidator.evWrongOutputType,
2993 TosaErrorValidator.evWrongInputList,
2994 TosaErrorValidator.evWrongOutputList,
2995 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00002996 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002997 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002998 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002999 "bitwise_and": {
3000 "op": Op.BITWISE_AND,
3001 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003002 "build_fcn": (
3003 build_binary_broadcast,
3004 TosaTensorGen.tgBroadcastFuzz,
3005 TosaTensorValuesGen.tvgDefault,
3006 None,
3007 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003008 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003009 "error_if_validators": (
3010 TosaErrorValidator.evRankMismatch,
3011 TosaErrorValidator.evWrongInputType,
3012 TosaErrorValidator.evWrongOutputType,
3013 TosaErrorValidator.evWrongInputList,
3014 TosaErrorValidator.evWrongOutputList,
3015 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003016 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003017 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003018 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003019 "bitwise_or": {
3020 "op": Op.BITWISE_OR,
3021 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003022 "build_fcn": (
3023 build_binary_broadcast,
3024 TosaTensorGen.tgBroadcastFuzz,
3025 TosaTensorValuesGen.tvgDefault,
3026 None,
3027 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003028 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003029 "error_if_validators": (
3030 TosaErrorValidator.evRankMismatch,
3031 TosaErrorValidator.evWrongInputType,
3032 TosaErrorValidator.evWrongOutputType,
3033 TosaErrorValidator.evWrongInputList,
3034 TosaErrorValidator.evWrongOutputList,
3035 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003036 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003037 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003038 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003039 "bitwise_xor": {
3040 "op": Op.BITWISE_XOR,
3041 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003042 "build_fcn": (
3043 build_binary_broadcast,
3044 TosaTensorGen.tgBroadcastFuzz,
3045 TosaTensorValuesGen.tvgDefault,
3046 None,
3047 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003048 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003049 "error_if_validators": (
3050 TosaErrorValidator.evRankMismatch,
3051 TosaErrorValidator.evWrongInputType,
3052 TosaErrorValidator.evWrongOutputType,
3053 TosaErrorValidator.evWrongInputList,
3054 TosaErrorValidator.evWrongOutputList,
3055 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003056 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003058 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003059 "intdiv": {
3060 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003061 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003062 "build_fcn": (
3063 build_binary_broadcast,
3064 TosaTensorGen.tgBroadcastFuzz,
3065 TosaTensorValuesGen.tvgIntDiv,
3066 None,
3067 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003068 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003069 "error_if_validators": (
3070 TosaErrorValidator.evRankMismatch,
3071 TosaErrorValidator.evWrongInputType,
3072 TosaErrorValidator.evWrongOutputType,
3073 TosaErrorValidator.evWrongInputList,
3074 TosaErrorValidator.evWrongOutputList,
3075 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003076 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003077 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003079 "logical_and": {
3080 "op": Op.LOGICAL_AND,
3081 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003082 "build_fcn": (
3083 build_binary_broadcast,
3084 TosaTensorGen.tgBroadcastFuzz,
3085 TosaTensorValuesGen.tvgDefault,
3086 None,
3087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003088 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003089 "error_if_validators": (
3090 TosaErrorValidator.evRankMismatch,
3091 TosaErrorValidator.evWrongInputType,
3092 TosaErrorValidator.evWrongOutputType,
3093 TosaErrorValidator.evWrongInputList,
3094 TosaErrorValidator.evWrongOutputList,
3095 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003096 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003097 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003098 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003099 "logical_left_shift": {
3100 "op": Op.LOGICAL_LEFT_SHIFT,
3101 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003102 "build_fcn": (
3103 build_binary_broadcast,
3104 TosaTensorGen.tgBroadcastFuzz,
3105 TosaTensorValuesGen.tvgLogicalShift,
3106 None,
3107 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003108 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003109 "error_if_validators": (
3110 TosaErrorValidator.evRankMismatch,
3111 TosaErrorValidator.evWrongInputType,
3112 TosaErrorValidator.evWrongOutputType,
3113 TosaErrorValidator.evWrongInputList,
3114 TosaErrorValidator.evWrongOutputList,
3115 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003116 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003117 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003118 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003119 "logical_right_shift": {
3120 "op": Op.LOGICAL_RIGHT_SHIFT,
3121 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003122 "build_fcn": (
3123 build_binary_broadcast,
3124 TosaTensorGen.tgBroadcastFuzz,
3125 TosaTensorValuesGen.tvgLogicalShift,
3126 None,
3127 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003128 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003129 "error_if_validators": (
3130 TosaErrorValidator.evRankMismatch,
3131 TosaErrorValidator.evWrongInputType,
3132 TosaErrorValidator.evWrongOutputType,
3133 TosaErrorValidator.evWrongInputList,
3134 TosaErrorValidator.evWrongOutputList,
3135 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003136 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003137 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003138 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "logical_or": {
3140 "op": Op.LOGICAL_OR,
3141 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003142 "build_fcn": (
3143 build_binary_broadcast,
3144 TosaTensorGen.tgBroadcastFuzz,
3145 TosaTensorValuesGen.tvgDefault,
3146 None,
3147 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003148 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003149 "error_if_validators": (
3150 TosaErrorValidator.evRankMismatch,
3151 TosaErrorValidator.evWrongInputType,
3152 TosaErrorValidator.evWrongOutputType,
3153 TosaErrorValidator.evWrongInputList,
3154 TosaErrorValidator.evWrongOutputList,
3155 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003156 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003157 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003159 "logical_xor": {
3160 "op": Op.LOGICAL_XOR,
3161 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003162 "build_fcn": (
3163 build_binary_broadcast,
3164 TosaTensorGen.tgBroadcastFuzz,
3165 TosaTensorValuesGen.tvgDefault,
3166 None,
3167 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003168 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003169 "error_if_validators": (
3170 TosaErrorValidator.evRankMismatch,
3171 TosaErrorValidator.evWrongInputType,
3172 TosaErrorValidator.evWrongOutputType,
3173 TosaErrorValidator.evWrongInputList,
3174 TosaErrorValidator.evWrongOutputList,
3175 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003176 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003177 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003178 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003179 "maximum": {
3180 "op": Op.MAXIMUM,
3181 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003182 "build_fcn": (
3183 build_binary_broadcast,
3184 TosaTensorGen.tgBroadcastFuzz,
3185 TosaTensorValuesGen.tvgDefault,
3186 None,
3187 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003188 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003189 "error_if_validators": (
3190 TosaErrorValidator.evRankMismatch,
3191 TosaErrorValidator.evWrongInputType,
3192 TosaErrorValidator.evWrongOutputType,
3193 TosaErrorValidator.evWrongInputList,
3194 TosaErrorValidator.evWrongOutputList,
3195 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003196 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003197 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003199 "minimum": {
3200 "op": Op.MINIMUM,
3201 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003202 "build_fcn": (
3203 build_binary_broadcast,
3204 TosaTensorGen.tgBroadcastFuzz,
3205 TosaTensorValuesGen.tvgDefault,
3206 None,
3207 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003208 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003209 "error_if_validators": (
3210 TosaErrorValidator.evRankMismatch,
3211 TosaErrorValidator.evWrongInputType,
3212 TosaErrorValidator.evWrongOutputType,
3213 TosaErrorValidator.evWrongInputList,
3214 TosaErrorValidator.evWrongOutputList,
3215 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003216 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003217 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003218 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003219 "mul": {
3220 "op": Op.MUL,
3221 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003222 "build_fcn": (
3223 build_mul,
3224 TosaTensorGen.tgBroadcastFuzz,
3225 TosaTensorValuesGen.tvgMul,
3226 TosaArgGen.agMul,
3227 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003228 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003229 "error_if_validators": (
3230 TosaErrorValidator.evWrongInputType,
3231 TosaErrorValidator.evWrongOutputType,
3232 TosaErrorValidator.evWrongInputList,
3233 TosaErrorValidator.evWrongOutputList,
3234 TosaErrorValidator.evRankMismatch,
3235 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003236 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003237 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003238 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003239 "pow": {
3240 "op": Op.POW,
3241 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003242 "build_fcn": (
3243 build_binary_broadcast,
3244 TosaTensorGen.tgBroadcastFuzz,
3245 TosaTensorValuesGen.tvgDefault,
3246 None,
3247 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003248 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003249 "error_if_validators": (
3250 TosaErrorValidator.evRankMismatch,
3251 TosaErrorValidator.evWrongInputType,
3252 TosaErrorValidator.evWrongOutputType,
3253 TosaErrorValidator.evWrongInputList,
3254 TosaErrorValidator.evWrongOutputList,
3255 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003256 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003257 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003258 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003259 "sub": {
3260 "op": Op.SUB,
3261 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003262 "build_fcn": (
3263 build_binary_broadcast,
3264 TosaTensorGen.tgBroadcastFuzz,
3265 TosaTensorValuesGen.tvgAddSub,
3266 None,
3267 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003268 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003269 "error_if_validators": (
3270 TosaErrorValidator.evRankMismatch,
3271 TosaErrorValidator.evWrongInputType,
3272 TosaErrorValidator.evWrongOutputType,
3273 TosaErrorValidator.evWrongInputList,
3274 TosaErrorValidator.evWrongOutputList,
3275 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003276 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003277 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003279 "table": {
3280 "op": Op.TABLE,
3281 # Use the automatic generation functions to create the input array
3282 # but create the table tensor in the build function, as it may be
3283 # a different type from the input
3284 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003285 "build_fcn": (
3286 build_table,
3287 TosaTensorGen.tgBasic,
3288 TosaTensorValuesGen.tvgDefault,
3289 TosaArgGen.agTable,
3290 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003291 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003292 "error_if_validators": (
3293 TosaErrorValidator.evWrongInputType,
3294 TosaErrorValidator.evWrongOutputType,
3295 TosaErrorValidator.evWrongInputList,
3296 TosaErrorValidator.evWrongOutputList,
3297 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003298 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003299 # Elementwise Unary operators
3300 "abs": {
3301 "op": Op.ABS,
3302 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003303 "build_fcn": (
3304 build_unary,
3305 TosaTensorGen.tgBasic,
3306 TosaTensorValuesGen.tvgDefault,
3307 None,
3308 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003309 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003310 "error_if_validators": (
3311 TosaErrorValidator.evWrongInputType,
3312 TosaErrorValidator.evWrongOutputType,
3313 TosaErrorValidator.evWrongInputList,
3314 TosaErrorValidator.evWrongOutputList,
3315 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003317 "bitwise_not": {
3318 "op": Op.BITWISE_NOT,
3319 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003320 "build_fcn": (
3321 build_unary,
3322 TosaTensorGen.tgBasic,
3323 TosaTensorValuesGen.tvgDefault,
3324 None,
3325 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003326 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003327 "error_if_validators": (
3328 TosaErrorValidator.evWrongInputType,
3329 TosaErrorValidator.evWrongOutputType,
3330 TosaErrorValidator.evWrongInputList,
3331 TosaErrorValidator.evWrongOutputList,
3332 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003333 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003334 "ceil": {
3335 "op": Op.CEIL,
3336 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003337 "build_fcn": (
3338 build_unary,
3339 TosaTensorGen.tgBasic,
3340 TosaTensorValuesGen.tvgDefault,
3341 None,
3342 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003343 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003344 "error_if_validators": (
3345 TosaErrorValidator.evWrongInputType,
3346 TosaErrorValidator.evWrongOutputType,
3347 TosaErrorValidator.evWrongInputList,
3348 TosaErrorValidator.evWrongOutputList,
3349 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003350 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003351 "clz": {
3352 "op": Op.CLZ,
3353 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003354 "build_fcn": (
3355 build_unary,
3356 TosaTensorGen.tgBasic,
3357 TosaTensorValuesGen.tvgDefault,
3358 None,
3359 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003361 "error_if_validators": (
3362 TosaErrorValidator.evWrongInputType,
3363 TosaErrorValidator.evWrongOutputType,
3364 TosaErrorValidator.evWrongInputList,
3365 TosaErrorValidator.evWrongOutputList,
3366 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003368 "exp": {
3369 "op": Op.EXP,
3370 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003371 "build_fcn": (
3372 build_unary,
3373 TosaTensorGen.tgBasic,
3374 TosaTensorValuesGen.tvgDefault,
3375 None,
3376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003377 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003378 "error_if_validators": (
3379 TosaErrorValidator.evWrongInputType,
3380 TosaErrorValidator.evWrongOutputType,
3381 TosaErrorValidator.evWrongInputList,
3382 TosaErrorValidator.evWrongOutputList,
3383 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003384 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003385 "floor": {
3386 "op": Op.FLOOR,
3387 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003388 "build_fcn": (
3389 build_unary,
3390 TosaTensorGen.tgBasic,
3391 TosaTensorValuesGen.tvgDefault,
3392 None,
3393 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003394 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003395 "error_if_validators": (
3396 TosaErrorValidator.evWrongInputType,
3397 TosaErrorValidator.evWrongOutputType,
3398 TosaErrorValidator.evWrongInputList,
3399 TosaErrorValidator.evWrongOutputList,
3400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003401 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003402 "log": {
3403 "op": Op.LOG,
3404 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003405 "build_fcn": (
3406 build_unary,
3407 TosaTensorGen.tgBasic,
3408 TosaTensorValuesGen.tvgDefault,
3409 None,
3410 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003411 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003412 "error_if_validators": (
3413 TosaErrorValidator.evWrongInputType,
3414 TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongInputList,
3416 TosaErrorValidator.evWrongOutputList,
3417 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003419 "logical_not": {
3420 "op": Op.LOGICAL_NOT,
3421 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003422 "build_fcn": (
3423 build_unary,
3424 TosaTensorGen.tgBasic,
3425 TosaTensorValuesGen.tvgDefault,
3426 None,
3427 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003428 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003429 "error_if_validators": (
3430 TosaErrorValidator.evWrongInputType,
3431 TosaErrorValidator.evWrongOutputType,
3432 TosaErrorValidator.evWrongInputList,
3433 TosaErrorValidator.evWrongOutputList,
3434 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003435 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003436 "negate": {
3437 "op": Op.NEGATE,
3438 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003439 "build_fcn": (
3440 build_unary,
3441 TosaTensorGen.tgBasic,
3442 TosaTensorValuesGen.tvgNegate,
3443 None,
3444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 "qgen": TosaQuantGen.qgUnary,
3446 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003447 "error_if_validators": (
3448 TosaErrorValidator.evInputZeroPointNotZero,
3449 TosaErrorValidator.evOutputZeroPointNotZero,
3450 TosaErrorValidator.evWrongInputType,
3451 TosaErrorValidator.evWrongOutputType,
3452 TosaErrorValidator.evWrongInputList,
3453 TosaErrorValidator.evWrongOutputList,
3454 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003455 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003456 "reciprocal": {
3457 "op": Op.RECIPROCAL,
3458 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003459 "build_fcn": (
3460 build_unary,
3461 TosaTensorGen.tgBasic,
3462 TosaTensorValuesGen.tvgDefault,
3463 None,
3464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003466 "error_if_validators": (
3467 TosaErrorValidator.evWrongInputType,
3468 TosaErrorValidator.evWrongOutputType,
3469 TosaErrorValidator.evWrongInputList,
3470 TosaErrorValidator.evWrongOutputList,
3471 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003472 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003473 "rsqrt": {
3474 "op": Op.RSQRT,
3475 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003476 "build_fcn": (
3477 build_unary,
3478 TosaTensorGen.tgBasic,
3479 TosaTensorValuesGen.tvgDefault,
3480 None,
3481 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003482 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003483 "error_if_validators": (
3484 TosaErrorValidator.evWrongInputType,
3485 TosaErrorValidator.evWrongOutputType,
3486 TosaErrorValidator.evWrongInputList,
3487 TosaErrorValidator.evWrongOutputList,
3488 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003489 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003490 # Elementwise Ternary operators
3491 "select": {
3492 "op": Op.SELECT,
3493 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003494 "build_fcn": (
3495 build_select,
3496 TosaTensorGen.tgBroadcastFuzz,
3497 TosaTensorValuesGen.tvgSelect,
3498 None,
3499 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003500 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003501 "error_if_validators": (
3502 TosaErrorValidator.evRankMismatch,
3503 TosaErrorValidator.evWrongInputType,
3504 TosaErrorValidator.evWrongOutputType,
3505 TosaErrorValidator.evWrongInputList,
3506 TosaErrorValidator.evWrongOutputList,
3507 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003508 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003509 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003510 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003511 # Comparison operators
3512 "equal": {
3513 "op": Op.EQUAL,
3514 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003515 "build_fcn": (
3516 build_comparison,
3517 TosaTensorGen.tgBroadcastFuzz,
3518 TosaTensorValuesGen.tvgEqual,
3519 None,
3520 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 "error_if_validators": (
3523 TosaErrorValidator.evRankMismatch,
3524 TosaErrorValidator.evWrongInputType,
3525 TosaErrorValidator.evWrongOutputType,
3526 TosaErrorValidator.evWrongInputList,
3527 TosaErrorValidator.evWrongOutputList,
3528 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003529 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003530 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003531 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003532 "greater_equal": {
3533 "op": Op.GREATER_EQUAL,
3534 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 "build_fcn": (
3536 build_comparison,
3537 TosaTensorGen.tgBroadcastFuzz,
3538 TosaTensorValuesGen.tvgDefault,
3539 None,
3540 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003541 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003542 "error_if_validators": (
3543 TosaErrorValidator.evRankMismatch,
3544 TosaErrorValidator.evWrongInputType,
3545 TosaErrorValidator.evWrongOutputType,
3546 TosaErrorValidator.evWrongInputList,
3547 TosaErrorValidator.evWrongOutputList,
3548 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003549 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003550 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003552 "greater": {
3553 "op": Op.GREATER,
3554 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003555 "build_fcn": (
3556 build_comparison,
3557 TosaTensorGen.tgBroadcastFuzz,
3558 TosaTensorValuesGen.tvgDefault,
3559 None,
3560 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003561 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003562 "error_if_validators": (
3563 TosaErrorValidator.evRankMismatch,
3564 TosaErrorValidator.evWrongInputType,
3565 TosaErrorValidator.evWrongOutputType,
3566 TosaErrorValidator.evWrongInputList,
3567 TosaErrorValidator.evWrongOutputList,
3568 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003569 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003570 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003572 # Reduction operators
3573 "reduce_all": {
3574 "op": Op.REDUCE_ALL,
3575 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003576 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003577 "build_fcn": (
3578 build_reduce,
3579 TosaTensorGen.tgBasic,
3580 TosaTensorValuesGen.tvgDefault,
3581 TosaArgGen.agAxis,
3582 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003583 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003584 "error_if_validators": (
3585 TosaErrorValidator.evAxisLargerRank,
3586 TosaErrorValidator.evAxisSmallerZero,
3587 TosaErrorValidator.evShapeOfAxisNotOne,
3588 TosaErrorValidator.evWrongInputType,
3589 TosaErrorValidator.evWrongOutputType,
3590 TosaErrorValidator.evWrongRank,
3591 TosaErrorValidator.evWrongInputList,
3592 TosaErrorValidator.evWrongOutputList,
3593 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003594 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003595 "reduce_any": {
3596 "op": Op.REDUCE_ANY,
3597 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003598 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003599 "build_fcn": (
3600 build_reduce,
3601 TosaTensorGen.tgBasic,
3602 TosaTensorValuesGen.tvgDefault,
3603 TosaArgGen.agAxis,
3604 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003605 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003606 "error_if_validators": (
3607 TosaErrorValidator.evAxisLargerRank,
3608 TosaErrorValidator.evAxisSmallerZero,
3609 TosaErrorValidator.evShapeOfAxisNotOne,
3610 TosaErrorValidator.evWrongInputType,
3611 TosaErrorValidator.evWrongOutputType,
3612 TosaErrorValidator.evWrongRank,
3613 TosaErrorValidator.evWrongInputList,
3614 TosaErrorValidator.evWrongOutputList,
3615 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003616 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003617 "reduce_max": {
3618 "op": Op.REDUCE_MAX,
3619 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003620 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003621 "build_fcn": (
3622 build_reduce,
3623 TosaTensorGen.tgBasic,
3624 TosaTensorValuesGen.tvgDefault,
3625 TosaArgGen.agAxis,
3626 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003627 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003628 "error_if_validators": (
3629 TosaErrorValidator.evAxisLargerRank,
3630 TosaErrorValidator.evAxisSmallerZero,
3631 TosaErrorValidator.evShapeOfAxisNotOne,
3632 TosaErrorValidator.evWrongInputType,
3633 TosaErrorValidator.evWrongOutputType,
3634 TosaErrorValidator.evWrongRank,
3635 TosaErrorValidator.evWrongInputList,
3636 TosaErrorValidator.evWrongOutputList,
3637 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003638 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003639 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003640 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003641 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003642 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 "build_fcn": (
3644 build_reduce,
3645 TosaTensorGen.tgBasic,
3646 TosaTensorValuesGen.tvgDefault,
3647 TosaArgGen.agAxis,
3648 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003649 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003650 "error_if_validators": (
3651 TosaErrorValidator.evAxisLargerRank,
3652 TosaErrorValidator.evAxisSmallerZero,
3653 TosaErrorValidator.evShapeOfAxisNotOne,
3654 TosaErrorValidator.evWrongInputType,
3655 TosaErrorValidator.evWrongOutputType,
3656 TosaErrorValidator.evWrongRank,
3657 TosaErrorValidator.evWrongInputList,
3658 TosaErrorValidator.evWrongOutputList,
3659 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003660 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003661 "reduce_product": {
3662 "op": Op.REDUCE_PRODUCT,
3663 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003664 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003665 "build_fcn": (
3666 build_reduce,
3667 TosaTensorGen.tgBasic,
3668 TosaTensorValuesGen.tvgDefault,
3669 TosaArgGen.agAxis,
3670 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003671 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003672 "error_if_validators": (
3673 TosaErrorValidator.evAxisLargerRank,
3674 TosaErrorValidator.evAxisSmallerZero,
3675 TosaErrorValidator.evShapeOfAxisNotOne,
3676 TosaErrorValidator.evWrongInputType,
3677 TosaErrorValidator.evWrongOutputType,
3678 TosaErrorValidator.evWrongRank,
3679 TosaErrorValidator.evWrongInputList,
3680 TosaErrorValidator.evWrongOutputList,
3681 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003682 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003683 "reduce_sum": {
3684 "op": Op.REDUCE_SUM,
3685 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003686 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003687 "build_fcn": (
3688 build_reduce,
3689 TosaTensorGen.tgBasic,
3690 TosaTensorValuesGen.tvgReduceSum,
3691 TosaArgGen.agAxis,
3692 ),
James Ward24dbc422022-10-19 12:20:31 +01003693 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 "error_if_validators": (
3695 TosaErrorValidator.evAxisLargerRank,
3696 TosaErrorValidator.evAxisSmallerZero,
3697 TosaErrorValidator.evShapeOfAxisNotOne,
3698 TosaErrorValidator.evWrongInputType,
3699 TosaErrorValidator.evWrongOutputType,
3700 TosaErrorValidator.evWrongRank,
3701 TosaErrorValidator.evWrongInputList,
3702 TosaErrorValidator.evWrongOutputList,
3703 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003704 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003705 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003706 "concat": {
3707 "op": Op.CONCAT,
3708 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003709 "build_fcn": (
3710 build_concat,
3711 TosaTensorGen.tgConcat,
3712 TosaTensorValuesGen.tvgConcat,
3713 TosaArgGen.agAxis,
3714 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003715 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003716 "error_if_validators": (
3717 TosaErrorValidator.evAxisLargerRank,
3718 TosaErrorValidator.evAxisSmallerZero,
3719 TosaErrorValidator.evConcatInputRankMismatch,
3720 TosaErrorValidator.evConcatShapeSumMismatch,
3721 TosaErrorValidator.evConcatInputDimMismatch,
3722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongOutputList,
3725 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003726 },
3727 "pad": {
3728 "op": Op.PAD,
3729 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003730 "build_fcn": (
3731 build_pad,
3732 TosaTensorGen.tgBasic,
3733 TosaTensorValuesGen.tvgDefault,
3734 TosaArgGen.agPad,
3735 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003736 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003737 "error_if_validators": (
3738 TosaErrorValidator.evWrongInputType,
3739 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003740 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003741 TosaErrorValidator.evWrongOutputType,
3742 TosaErrorValidator.evWrongInputList,
3743 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003744 TosaErrorValidator.evRankMismatch,
3745 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003746 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003747 },
3748 "reshape": {
3749 "op": Op.RESHAPE,
3750 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003751 "build_fcn": (
3752 build_reshape,
3753 TosaTensorGen.tgBasic,
3754 TosaTensorValuesGen.tvgDefault,
3755 TosaArgGen.agReshape,
3756 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003757 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003758 "error_if_validators": (
3759 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3760 TosaErrorValidator.evWrongInputType,
3761 TosaErrorValidator.evWrongOutputType,
3762 TosaErrorValidator.evWrongInputList,
3763 TosaErrorValidator.evWrongOutputList,
Jerry Ge264f7fa2023-04-21 22:49:57 +00003764 TosaErrorValidator.evReshapeOutputSizeMultiInference,
3765 TosaErrorValidator.evReshapeOutputSizeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003766 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003767 },
3768 "reverse": {
3769 "op": Op.REVERSE,
3770 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003771 "build_fcn": (
3772 build_reverse,
3773 TosaTensorGen.tgBasic,
3774 TosaTensorValuesGen.tvgDefault,
3775 TosaArgGen.agAxis,
3776 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003777 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003778 "error_if_validators": (
3779 TosaErrorValidator.evAxisSmallerZero,
3780 TosaErrorValidator.evAxisLargerRank,
3781 TosaErrorValidator.evWrongInputType,
3782 TosaErrorValidator.evWrongOutputType,
3783 TosaErrorValidator.evWrongInputList,
3784 TosaErrorValidator.evWrongOutputList,
3785 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003786 },
3787 "slice": {
3788 "op": Op.SLICE,
3789 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003790 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003791 "build_fcn": (
3792 build_slice,
3793 TosaTensorGen.tgBasic,
3794 TosaTensorValuesGen.tvgDefault,
3795 TosaArgGen.agSlice,
3796 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003797 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003798 "error_if_validators": (
3799 TosaErrorValidator.evStartSmallerZero,
3800 TosaErrorValidator.evSizeSmallerEqualZero,
3801 TosaErrorValidator.evStartSizeOutsideBounds,
3802 TosaErrorValidator.evSizeOutputShapeMismatch,
3803 TosaErrorValidator.evInputSizeStartLengthMismatch,
3804 TosaErrorValidator.evWrongRank,
3805 TosaErrorValidator.evWrongInputType,
3806 TosaErrorValidator.evWrongOutputType,
3807 TosaErrorValidator.evWrongInputList,
3808 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003809 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003810 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003811 },
3812 "tile": {
3813 "op": Op.TILE,
3814 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003815 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003816 "build_fcn": (
3817 build_tile,
3818 TosaTensorGen.tgBasic,
3819 TosaTensorValuesGen.tvgDefault,
3820 TosaArgGen.agTile,
3821 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003822 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003823 "error_if_validators": (
3824 TosaErrorValidator.evWrongInputType,
3825 TosaErrorValidator.evWrongOutputType,
3826 TosaErrorValidator.evWrongInputList,
3827 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003828 TosaErrorValidator.evRankMismatch,
3829 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003830 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003831 },
3832 "transpose": {
3833 "op": Op.TRANSPOSE,
3834 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003835 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003836 "build_fcn": (
3837 build_transpose,
3838 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003839 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003840 TosaArgGen.agTranspose,
3841 ),
3842 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003843 "error_if_validators": (
3844 TosaErrorValidator.evIndexOutsideBounds,
3845 TosaErrorValidator.evIndexUsedTwice,
3846 TosaErrorValidator.evWrongInputType,
3847 TosaErrorValidator.evWrongOutputType,
3848 TosaErrorValidator.evWrongInputList,
3849 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003850 TosaErrorValidator.evWrongRank,
3851 TosaErrorValidator.evRankMismatch,
3852 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003853 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003854 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003855 # Data nodes
3856 "const": {
3857 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003858 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003859 "build_fcn": (
3860 build_const,
3861 TosaTensorGen.tgBasic,
3862 TosaTensorValuesGen.tvgDefault,
3863 None,
3864 ),
Luke Hutton65872422023-02-20 10:33:04 +00003865 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 "identity": {
3868 "op": Op.IDENTITY,
3869 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003870 "build_fcn": (
3871 build_unary,
3872 TosaTensorGen.tgBasic,
3873 TosaTensorValuesGen.tvgDefault,
3874 None,
3875 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003876 "types": TYPE_FIB,
3877 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003878 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003879 "gather": {
3880 "op": Op.GATHER,
3881 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3882 "operands": (1, 0),
3883 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003884 "build_fcn": (
3885 build_gather,
3886 TosaTensorGen.tgBasic,
3887 TosaTensorValuesGen.tvgDefault,
3888 None,
3889 ),
James Ward24dbc422022-10-19 12:20:31 +01003890 "types": (
3891 DType.INT8,
3892 DType.INT16,
3893 DType.INT32,
3894 DType.FP16,
3895 DType.BF16,
3896 DType.FP32,
3897 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003898 "error_if_validators": (
3899 TosaErrorValidator.evWrongInputType,
3900 TosaErrorValidator.evWrongOutputType,
3901 TosaErrorValidator.evWrongInputList,
3902 TosaErrorValidator.evWrongOutputList,
3903 TosaErrorValidator.evWrongRank,
3904 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003905 },
3906 "scatter": {
3907 "op": Op.SCATTER,
3908 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003909 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003910 "operands": (2, 0),
3911 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003912 "build_fcn": (
3913 build_scatter,
3914 TosaTensorGen.tgScatter,
3915 TosaTensorValuesGen.tvgDefault,
3916 None,
3917 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003918 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 "error_if_validators": (
3920 TosaErrorValidator.evWrongInputType,
3921 TosaErrorValidator.evWrongOutputType,
3922 TosaErrorValidator.evWrongInputList,
3923 TosaErrorValidator.evWrongOutputList,
3924 TosaErrorValidator.evWrongRank,
3925 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003926 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003927 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003928 "resize": {
3929 "op": Op.RESIZE,
3930 "operands": (1, 0),
3931 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003932 "build_fcn": (
3933 build_resize,
3934 TosaTensorGen.tgNHWC,
3935 TosaTensorValuesGen.tvgDefault,
3936 TosaArgGen.agResize,
3937 ),
James Ward24dbc422022-10-19 12:20:31 +01003938 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 "invalid_test_validators": (
3940 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003941 ),
3942 "error_if_validators": (
3943 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003944 TosaErrorValidator.evScaleSmallerEqualZero,
3945 TosaErrorValidator.evScaleNLargerMax,
3946 TosaErrorValidator.evScaleDLargerMax,
3947 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003948 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003949 TosaErrorValidator.evBorderSmallerMin,
3950 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003951 TosaErrorValidator.evWrongInputType,
3952 TosaErrorValidator.evWrongOutputType,
3953 TosaErrorValidator.evWrongRank,
3954 TosaErrorValidator.evWrongInputList,
3955 TosaErrorValidator.evWrongOutputList,
3956 TosaErrorValidator.evBatchMismatch,
3957 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003958 TosaErrorValidator.evResizeOutputShapeMismatch,
3959 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003960 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003961 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003962 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003963 "cast": {
3964 "op": Op.CAST,
3965 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003966 "build_fcn": (
3967 build_cast,
3968 TosaTensorGen.tgBasic,
3969 TosaTensorValuesGen.tvgDefault,
3970 TosaArgGen.agCast,
3971 ),
James Ward8b390432022-08-12 20:48:56 +01003972 "types": (
3973 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003974 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003975 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003976 DType.INT8,
3977 DType.INT16,
3978 DType.INT32,
3979 DType.BOOL,
3980 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003981 "error_if_validators": (
3982 TosaErrorValidator.evWrongInputType,
3983 TosaErrorValidator.evWrongOutputType,
3984 TosaErrorValidator.evWrongInputList,
3985 TosaErrorValidator.evWrongOutputList,
3986 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003987 },
3988 "rescale": {
3989 "op": Op.RESCALE,
3990 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003991 "build_fcn": (
3992 build_rescale,
3993 TosaTensorGen.tgBasic,
3994 TosaTensorValuesGen.tvgDefault,
3995 TosaArgGen.agRescale,
3996 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003997 "types": [
3998 DType.UINT8,
3999 DType.INT8,
4000 DType.INT16,
4001 DType.INT32,
4002 DType.INT48,
4003 DType.UINT16,
4004 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004005 "error_if_validators": (
4006 TosaErrorValidator.evInputZeroPointNotZero,
4007 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004008 TosaErrorValidator.evU16InputZeroPointNotValid,
4009 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004010 TosaErrorValidator.evScaleTrue,
4011 TosaErrorValidator.evScaleNotTrue,
4012 TosaErrorValidator.evWrongInputType,
4013 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004014 TosaErrorValidator.evWrongInputList,
4015 TosaErrorValidator.evWrongOutputList,
4016 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004017 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004018 # Custom
4019 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004020 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004021 # Two variants of cond_if, one that generates one of two constant tensors (no
4022 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4023 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004024 "cond_if_const": {
4025 "op": Op.COND_IF,
4026 "operands": (0, 2),
4027 "build_fcn": (
4028 build_cond_if_const,
4029 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004030 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004031 TosaArgGen.agCondIf,
4032 ),
4033 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004034 "error_if_validators": (
4035 TosaErrorValidator.evOutputListThenGraphMismatch,
4036 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004037 TosaErrorValidator.evCondIfCondNotMatchingBool,
4038 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004039 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004040 },
4041 "cond_if_binary": {
4042 "op": Op.COND_IF,
4043 "operands": (2, 0),
4044 "build_fcn": (
4045 build_cond_if_binary,
4046 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004047 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004048 TosaArgGen.agCondIf,
4049 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004050 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004051 "error_if_validators": (
4052 TosaErrorValidator.evInputListThenGraphMismatch,
4053 TosaErrorValidator.evInputListElseGraphMismatch,
4054 TosaErrorValidator.evOutputListThenGraphMismatch,
4055 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004056 TosaErrorValidator.evCondIfCondNotMatchingBool,
4057 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004058 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004059 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004060 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004061 "while_loop": {
4062 "op": Op.WHILE_LOOP,
4063 "operands": (0, 1),
4064 "build_fcn": (
4065 build_while_loop,
4066 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004068 TosaArgGen.agWhileLoop,
4069 ),
4070 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004071 "error_if_validators": (
4072 TosaErrorValidator.evInputListOutputListMismatch,
4073 TosaErrorValidator.evInputListCondGraphMismatch,
4074 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4075 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4076 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004077 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004078 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004079 },
Luke Hutton57287132023-02-06 14:54:18 +00004080 "fft2d": {
4081 "op": Op.FFT2D,
4082 "operands": (2, 0),
4083 "rank": (3, 3),
4084 "build_fcn": (
4085 build_fft2d,
4086 TosaTensorGen.tgFFT2d,
4087 TosaTensorValuesGen.tvgDefault,
4088 TosaArgGen.agFFT2d,
4089 ),
4090 "types": [DType.FP32],
4091 "error_if_validators": (
4092 TosaErrorValidator.evWrongInputType,
4093 TosaErrorValidator.evWrongOutputType,
4094 TosaErrorValidator.evWrongInputList,
4095 TosaErrorValidator.evWrongOutputList,
4096 TosaErrorValidator.evWrongRank,
4097 TosaErrorValidator.evBatchMismatch,
4098 TosaErrorValidator.evKernelNotPowerOfTwo,
4099 TosaErrorValidator.evFFTInputShapeMismatch,
4100 TosaErrorValidator.evFFTOutputShapeMismatch,
4101 ),
4102 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004103 "rfft2d": {
4104 "op": Op.RFFT2D,
4105 "operands": (1, 0),
4106 "rank": (3, 3),
4107 "build_fcn": (
4108 build_rfft2d,
4109 TosaTensorGen.tgRFFT2d,
4110 TosaTensorValuesGen.tvgDefault,
4111 TosaArgGen.agNone,
4112 ),
4113 "types": [DType.FP32],
4114 "error_if_validators": (
4115 TosaErrorValidator.evWrongInputType,
4116 TosaErrorValidator.evWrongOutputType,
4117 TosaErrorValidator.evWrongInputList,
4118 TosaErrorValidator.evWrongOutputList,
4119 TosaErrorValidator.evWrongRank,
4120 TosaErrorValidator.evBatchMismatch,
4121 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004122 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004123 ),
4124 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004125 }
4126
Kevin Cheng550ccc52021-03-03 11:21:43 -08004127
Eric Kunzee5e26762020-10-13 16:11:07 -07004128class OutputShaper:
4129 # Methods in this class compute the expected output shape and datatype
4130 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; nothing to initialise, no attributes are set."""
4133
4134 # These methods return arguments that can be used for
4135 # creating a new output tensor
4136 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Create the output tensor for an elementwise binary op that broadcasts.

    Output shape: for each dimension of ``a``, take ``b``'s size where
    ``a``'s size is 1 and no ERROR_IF case is being generated; otherwise
    keep ``a``'s size.

    Args:
        ser: serializer; the output is registered via
            ``ser.addOutput(shape, dtype)``.
        rng: random generator exposing ``integers`` and ``choice``
            (numpy ``Generator``-style API).
        a: first input tensor (its ``shape`` and ``dtype`` are read).
        b: second input tensor; must share ``a``'s dtype and, unless
            ``error_name`` is ``ErrorIf.RankMismatch``, its rank.
        error_name: ERROR_IF case to generate, or None.
            ``DimensionMismatch`` grows one randomly chosen output dimension
            by 1 (not applicable to rank-0 inputs, whose shape is left as-is);
            ``WrongOutputType`` picks an output dtype different from
            ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and error_name is None:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The index is drawn even when it is not used, so the values consumed
    # from rng match the original unconditional draw for rank >= 1 inputs.
    # A rank-0 input has no dimension to pick and rng.integers(0, 0) raises
    # ValueError, so the draw is skipped in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # Candidate dtypes for a deliberately wrong output type (BOOL is not
        # among them); the input dtype is removed so the pick always differs.
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004169
4170 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Create the output tensor for a binary op whose inputs must match exactly.

    ``a`` and ``b`` are asserted to have equal rank, equal dtype and equal
    sizes in every dimension; no broadcasting is performed. The output takes
    a fresh list copy of ``a``'s shape and ``a``'s dtype.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    # Every dimension must agree exactly.
    assert all(dim_a == dim_b for dim_a, dim_b in zip(a.shape, b.shape))
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004181
4182 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Create the output tensor for an elementwise unary operator.

    The output shape is always ``a.shape``. The output dtype is ``a.dtype``,
    except when a WrongOutputType ERROR_IF test is requested: then ``rng``
    chooses from a fixed candidate set with ``a.dtype`` removed.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        # Candidate dtypes for a deliberately wrong output type; the input
        # dtype is subtracted so the chosen dtype always differs from it.
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004200
4201 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Create the output tensor for the SELECT operator (inputs cond, a, b).

    Output shape: for each dimension of ``cond``, use the maximum of the
    ``cond``/``a``/``b`` sizes where ``cond``'s size is 1 and no ERROR_IF
    case is being generated; otherwise keep ``cond``'s size.

    Args:
        ser: serializer; the output is registered via
            ``ser.addOutput(shape, dtype)``.
        rng: random generator exposing ``integers`` and ``choice``
            (numpy ``Generator``-style API).
        cond: condition tensor (its ``shape`` is read).
        a, b: value tensors; must share a dtype and, unless ``error_name`` is
            ``ErrorIf.RankMismatch``, the same rank as ``cond``.
        error_name: ERROR_IF case to generate, or None.
            ``DimensionMismatch`` grows one randomly chosen output dimension
            by 1 (not applicable to rank-0 inputs, whose shape is left as-is);
            ``WrongOutputType`` picks an output dtype different from
            ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = []
    for i in range(len(cond.shape)):
        if cond.shape[i] == 1 and error_name is None:
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
        else:
            shape.append(cond.shape[i])

    # The index is drawn even when it is not used, so the values consumed
    # from rng match the original unconditional draw for rank >= 1 inputs.
    # A rank-0 input has no dimension to pick and rng.integers(0, 0) raises
    # ValueError, so the draw is skipped in that case.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # Candidate dtypes for a deliberately wrong output type (BOOL is not
        # among them); the input dtype is removed so the pick always differs.
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004234
4235 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the BOOL output tensor of a binary comparison operator.

    Output dimension ``i`` is ``b.shape[i]`` when ``a.shape[i] == 1`` and ``b``
    has that dimension (broadcast), otherwise ``a.shape[i]``.

    Args:
        ser: serializer; only ``ser.addOutput(shape, dtype)`` is used.
        rng: random generator providing ``integers`` and ``choice``
            (presumably a ``numpy.random.Generator``).
        a, b: input tensors; ``.shape`` and ``.dtype`` are read.
        error_name: ``None`` for a valid test, otherwise the ``ErrorIf`` case
            being generated. RankMismatch skips the rank assertion,
            DimensionMismatch grows one random output dimension by 1 and
            WrongOutputType picks a non-BOOL output dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Do broadcast
    shape = []
    for i in range(len(a.shape)):
        if a.shape[i] == 1 and len(b.shape) > i:
            shape.append(b.shape[i])
        else:
            shape.append(a.shape[i])

    # The fuzz index is drawn on every call, not only for DimensionMismatch,
    # so the random stream consumed by later draws stays the same. Rank-0
    # inputs are skipped: rng.integers(0, 0) raises ValueError for a numpy
    # Generator, and there is no dimension to fuzz in that case anyway.
    if len(a.shape) > 0:
        fuzz_idx = rng.integers(0, len(a.shape))
        if error_name == ErrorIf.DimensionMismatch:
            shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        # BOOL is not in this list, so every choice is a wrong output type.
        wrong_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = DType.BOOL

    return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004268
4269 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a REDUCE_* operator.

    Normally the output is ``a``'s shape with ``axis`` collapsed to extent 1.
    For AxisSmallerZero / AxisLargerRank the input shape is kept as-is. For
    ShapeOfAxisNotOne the reduced axis is left unchanged, or set to a random
    extent in [2, 10) if it was already 1. WrongOutputType draws an output
    dtype different from ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    reduced_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if reduced_shape[axis] == 1:
            reduced_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        reduced_shape[axis] = 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(reduced_shape, a.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    mismatched = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(reduced_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004297
4298 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of ARGMAX.

    The output is ``a``'s shape with ``axis`` removed and dtype INT32. ErrorIf
    variants: AxisSmallerZero / AxisLargerRank keep the full input shape;
    ArgmaxOutputRankMismatch randomly drops the leading dimension (only if more
    than one remains) or appends a 1; ArgmaxOutputShapeMismatch grows every
    dimension by a random amount in [1, 10); WrongOutputType draws a non-INT32
    dtype.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape.pop(axis)

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            out_shape.pop(0)
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [extent + rng.integers(1, 10) for extent in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004331
4332 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of CONV2D with the serializer.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.

    Args:
        ser: serializer; only ``ser.addOutput(shape, dtype)`` is used.
        rng: random generator (``choice`` is used).
        ifm: input tensor; ``.shape`` and ``.dtype`` are read.
        filter: weight tensor; ``.shape`` is read.
        accum_dtype: output dtype for valid tests.
        strides: ``strides[0]`` applies to H, ``strides[1]`` to W.
        padding: ``padding[0:2]`` pad H, ``padding[2:4]`` pad W.
        dilations: ``dilations[0]`` applies to H, ``dilations[1]`` to W.
        error_name: ``None`` or the ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_extent(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of a dilated, padded, strided window (floor division).
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    out_h = _out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # size error case is not triggered as well.
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # With a deliberately wrong input type, use INT32 as a potentially correct
    # output type so only the input check fails.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably because
        # either may be a valid output type).
        excluded = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004383
4384 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of CONV3D with the serializer.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.

    Args:
        ser: serializer; only ``ser.addOutput(shape, dtype)`` is used.
        rng: random generator (``choice`` is used).
        ifm: input tensor; ``.shape`` and ``.dtype`` are read.
        filter: weight tensor; ``.shape`` is read.
        accum_dtype: output dtype for valid tests.
        strides: per-axis strides for D, H, W (indices 0, 1, 2).
        padding: ``padding[0:2]`` pad D, ``padding[2:4]`` H, ``padding[4:6]`` W.
        dilations: per-axis dilations for D, H, W (indices 0, 1, 2).
        error_name: ``None`` or the ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_extent(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of a dilated, padded, strided window (floor division).
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    out_d = _out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = _out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = _out_extent(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1 -> grow D, 2 -> grow H, 3 -> grow W, 4 -> grow all three
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # size error case is not triggered as well.
        if which in (1, 4):
            out_d += rng.choice(steps) * strides[0]
        if which in (2, 4):
            out_h += rng.choice(steps) * strides[1]
        if which in (3, 4):
            out_w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # With a deliberately wrong input type, use INT32 as a potentially correct
    # output type so only the input check fails.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably because
        # either may be a valid output type).
        excluded = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
4445
4446 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor of DEPTHWISE_CONV2D with the serializer.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    Args:
        ser: serializer; only ``ser.addOutput(shape, dtype)`` is used.
        rng: random generator (``choice`` is used).
        ifm: input tensor; ``.shape`` and ``.dtype`` are read.
        filter: weight tensor; ``.shape`` is read.
        accum_dtype: output dtype for valid tests.
        strides: ``strides[0]`` applies to H, ``strides[1]`` to W.
        padding: ``padding[0:2]`` pad H, ``padding[2:4]`` pad W.
        dilations: ``dilations[0]`` applies to H, ``dilations[1]`` to W.
        error_name: ``None`` or the ``ErrorIf`` case being generated.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """

    def _out_extent(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent of a dilated, padded, strided window (floor division).
        return (in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1

    # HWCM filter: kernel height/width are dimensions 0 and 1.
    out_h = _out_extent(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _out_extent(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # size error case is not triggered as well.
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    # With a deliberately wrong input type, use INT32 as a potentially correct
    # output type so only the input check fails.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably because
        # either may be a valid output type).
        excluded = [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004496
4497 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Register the output tensor of a 2D pooling operator (NHWC input).

    Spatial extents are ``(in + pad_a + pad_b - kernel) // stride + 1``. With a
    non-positive stride or any negative padding, H and W are set to 1 (the test
    is invalid anyway). PoolingOutputShapeMismatch grows H and/or W by a
    multiple of the stride; WrongOutputType draws a dtype other than
    ``ifm.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    params_valid = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_valid:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride/padding: the test is invalid anyway, so use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> grow H, 2 -> grow W, 3 -> grow both
        which = rng.choice(steps)
        # Grow by whole multiples of the stride so the non-integer output
        # size error case is not triggered as well.
        if which in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004534
4535 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Register the output tensor of FULLY_CONNECTED.

    Shapes: input is [N, IC], filter is [OC, IC], output is [N, OC].

    ``accum_dtype`` is used as the output dtype unchanged: it has already been
    validated by the argument generator (and deliberately invalidated there for
    ERROR_IF tests). ``rng`` and ``error_name`` are not used.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004547
4548 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Register the output tensor of MATMUL with the serializer.

    Shapes: ``a`` is [N, H, C], ``b`` is [N, C, W], output is [N, H, W].

    Args:
        ser: serializer; only ``ser.addOutput(shape, dtype)`` is used.
        rng: random generator (``choice`` is used).
        a, b: input tensors; ``.shape`` is read from both, ``.dtype`` from ``a``.
        accum_dtype: output dtype for valid tests (validated in arg_gen).
        error_name: ``None`` or the ``ErrorIf`` case being generated.
            WrongOutputType draws an output dtype that is invalid for
            ``a.dtype``; WrongInputType uses INT32.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif (
            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
        ):
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            # Input dtypes not listed above are not expected here. Rather than
            # leaving incorrect_types unbound (UnboundLocalError), pick any
            # listed dtype that differs from the accumulator type.
            incorrect_types = tuple(
                dt
                for dt in (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                    DType.FP32,
                    DType.FP16,
                    DType.BF16,
                )
                if dt != accum_dtype
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004595
4596 @staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Register the output tensor of CONCAT along ``axis``.

    The output starts from the first input's shape, and the ``axis`` extents
    of the remaining inputs are added to it. The sum is skipped (first shape
    kept) when the inputs' ranks differ or the axis is invalid.
    ConcatShapeSumMismatch adds a random extra in [5, 10) to the axis, and
    WrongOutputType draws a dtype other than the first input's.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    first = a[0]
    rest = a[1:]

    out_shape = first.shape.copy()
    # Summing along the axis is impossible for rank-mismatched inputs or an
    # out-of-range axis.
    skip_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not skip_sum:
        for other in rest:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, first.dtype)

    candidate_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    mismatched = list(candidate_dtypes - set([first.dtype]))
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004631
4632 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register the output tensor of PAD.

    Output extent ``i`` is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    PadOutputShapeMismatch shrinks one random dimension by 1 or 2; RankMismatch
    replaces the shape via ``get_rank_mismatch_shape``; WrongOutputType draws a
    dtype other than ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    out_shape = a.shape.copy()
    for axis, extent in enumerate(a.shape):
        out_shape[axis] = padding[axis][0] + padding[axis][1] + extent

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(out_shape)))
        shrink = rng.choice([1, 2])
        out_shape[bad_dim] -= shrink
    elif error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004662
4663 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register the output tensor of RESHAPE.

    The output shape is a copy of ``shape``. TensorSizeInputOutputMismatch grows
    every dimension by a random amount in [1, 10); WrongOutputType draws a
    dtype other than ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, extent in enumerate(new_shape):
            new_shape[idx] = extent + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    mismatched = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(new_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07004687
4688 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Register the output tensor of SLICE.

    The output shape is a copy of ``size`` (``start`` is not used here) and the
    dtype is ``input.dtype``. ErrorIf variants: WrongOutputType draws another
    dtype; SizeOutputShapeMismatch perturbs every dimension (extents <= 2 only
    grow); InputSizeStartLengthMismatch uses the input shape; RankMismatch uses
    ``get_rank_mismatch_shape``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, extent in enumerate(out_shape):
            # Small extents are only increased (presumably to keep them positive).
            if extent <= 2:
                delta = rng.choice([1, 2])
            else:
                delta = rng.choice([-2, -1, 1, 2])
            out_shape[idx] = extent + delta
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004721
4722 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor of TILE.

    Output extent ``i`` is ``a.shape[i] * multiples[i]``; ``multiples`` must
    have one entry per input dimension. RankMismatch replaces the shape via
    ``get_rank_mismatch_shape``; WrongOutputType draws a dtype other than
    ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for axis, extent in enumerate(a.shape):
        tiled_shape[axis] = extent * multiples[axis]

    if error_name == ErrorIf.RankMismatch:
        tiled_shape = get_rank_mismatch_shape(rng, tiled_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(tiled_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004750
4751 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor of TRANSPOSE.

    Output extent ``i`` is ``a.shape[perms[i]]``, except for IndexOutsideBounds
    and IndexUsedTwice, where ``perms`` is invalid and the input shape is kept.
    TensorSizeInputOutputMismatch grows every dimension by a random amount in
    [1, 10); RankMismatch uses ``get_rank_mismatch_shape``; WrongOutputType
    draws a dtype other than ``a.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    assert len(perms) == len(a.shape)

    if error_name in (ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice):
        out_shape = a.shape.copy()
    else:
        out_shape = [a.shape[perms[i]] for i in range(len(a.shape))]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        out_shape = [extent + rng.integers(1, 10) for extent in out_shape]
    elif error_name == ErrorIf.RankMismatch:
        out_shape = get_rank_mismatch_shape(rng, out_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004783
4784 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor of GATHER.

    ``values`` is asserted rank 3 (skipped for WrongRank) and ``indices`` rank 2
    with a matching leading dimension. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``; WrongOutputType
    draws a dtype other than ``values.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    out_shape = [batch, width, channels]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, values.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    mismatched = list(set(candidate_dtypes) - {values.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004809
4810 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor of SCATTER.

    Asserts the N, W and C dimensions agree between ``values_in``, ``indices``
    and ``input`` (the rank-3 check on ``values_in`` is skipped for WrongRank).
    The output has ``values_in``'s shape (the same list object is passed on).
    WrongOutputType draws a dtype other than ``values_in.dtype``.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    out_dtype = values_in.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values_in.dtype}))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004838
4839 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor of TABLE.

    The output has the input's shape. Its dtype is INT32 for an INT16 input and
    INT8 otherwise; for WrongOutputType a different dtype is drawn. The input
    dtype is asserted to be INT8 or INT16 unless WrongInputType is requested.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        out_dtype = DType.INT32
    else:
        out_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        alternatives = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        alternatives.remove(out_dtype)
        out_dtype = rng.choice(alternatives)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004858
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor of a RESIZE operation.

    Output height/width are computed per axis as
    ``((in - 1) * scale_n - offset + border) // scale_d + 1`` where
    ``scale`` is ``[y_n, y_d, x_n, x_d]`` and ``offset``/``border`` are
    ``[y, x]``.  Batch (dim 0) and channels (dim 3) are copied from the
    input unless an ERROR_IF case perturbs them.

    Args:
        serializer: serializer whose ``addOutput`` creates the tensor.
        rng: random generator used to choose perturbations.
        input: input tensor; dims 1 and 2 feed the height/width calculation.
        mode: not used by this function.
        scale: numerator/denominator pairs for the y and x axes.
        offset: y and x offsets.
        border: y and x borders.
        input_dtype: not used by this function.
        output_dtype: dtype given to the output tensor.
        error_name: optional ErrorIf case to provoke.

    Returns:
        Whatever ``serializer.addOutput`` returns for the new tensor.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # The scale may be <= 0 for this error; clamp to at least 1 so the
        # size calculation below stays well defined.
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # scale/offset/border may have been altered for an ERROR_IF, which
        # can give a non-positive size; keep the output tensor valid.
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            # Clamp below MAX_RESIZE_DIMENSION except for MaxDimExceeded
            cap = MAX_RESIZE_DIMENSION - 1
            out_h = min(out_h, cap)
            out_w = min(out_w, cap)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        axes = rng.choice([1, 2, 3])  # 1: height, 2: width, 3: both
        # Step by whole multiples of the scale denominator so the
        # non-integer error case is not triggered as well.
        if axes in (1, 3):
            if out_h + y_den < MAX_RESIZE_DIMENSION:
                out_h += y_den
            else:
                out_h -= y_den
                assert out_h > 0  # Should have been caught in agResize
        if axes in (2, 3):
            if out_w + x_den < MAX_RESIZE_DIMENSION:
                out_w += x_den
            else:
                out_w -= x_den
                assert out_w > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim reuses the batch size, presumably because
        # a wrong-rank input may have no dimension 3 to copy - confirm.
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004937
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register the output tensor of a type-conversion operator.

    The output has the same shape as ``val`` and the requested dtype.
    ``rng`` and ``error_name`` are accepted for interface uniformity with
    the other shapers but are not used.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004941
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor of a TRANSPOSE_CONV2D operation.

    The output uses ``output_shape`` and, normally, ``accum_dtype``.

    Args:
        ser: serializer whose ``addOutput`` creates the tensor.
        rng: random generator used to choose perturbations.
        ifm: input tensor; only its dtype is inspected here.
        output_shape: requested output shape.  NOTE: for
            ``ConvOutputShapeMismatch`` dims 1 and 2 are modified IN PLACE,
            so the caller's object sees the perturbed values too.
        accum_dtype: accumulator dtype, used as the output dtype.
        error_name: optional ErrorIf case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        which = rng.choice(deltas)  # 1: dim 1, 2: dim 2, 3: both
        if which in (1, 3):
            output_shape[1] += rng.choice(deltas)
        if which in (2, 3):
            output_shape[2] += rng.choice(deltas)

    # With an invalid input type, fall back to a plausible INT32 output.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded (presumably both
        # count as valid outputs there - confirm against the spec).
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004967
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors of an FFT2D operation.

    Both outputs take the shape and dtype of ``ifm1``, optionally perturbed
    to provoke the requested ERROR_IF case.

    Args:
        serializer: serializer whose ``addOutput`` creates the tensors.
        rng: random generator used to choose perturbations.
        ifm1: first input tensor; its shape/dtype seed the outputs.
        ifm2: second input tensor; must share ``ifm1``'s dtype, and its
            shape unless ``FFTInputShapeMismatch`` is being generated.
        error_name: optional ErrorIf case to provoke.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    input_dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    input_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(input_shape) == 3

    # Build a fresh list (as rfft2dOp does) so the perturbations below never
    # touch ifm1's shape and any sequence type (e.g. a tuple) is accepted;
    # ``.copy()`` only works for list/ndarray shapes.
    output_shape = list(input_shape)
    output_dtype = input_dtype

    if error_name == ErrorIf.WrongOutputType:
        excludes = [DType.FP32]
        wrong_dtypes = list(usableDTypes(excludes=excludes))
        output_dtype = rng.choice(wrong_dtypes)
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Perturb one of the two non-batch dimensions.
        modify_dim = rng.choice([1, 2])
        output_shape[modify_dim] += rng.integers(1, 10)

    # Both outputs share the same shape and dtype.
    return [
        serializer.addOutput(output_shape, output_dtype),
        serializer.addOutput(output_shape, output_dtype),
    ]
4998
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors of an RFFT2D operation.

    Both outputs copy the input shape except the last dimension, which
    becomes ``W // 2 + 1``.  Both use the input dtype unless an ERROR_IF
    case perturbs the dtype or shape.

    Args:
        serializer: serializer whose ``addOutput`` creates the tensors.
        rng: random generator used to choose perturbations.
        value: input tensor; asserted to be rank 3 unless ``WrongRank``.
        error_name: optional ErrorIf case to provoke.

    Returns:
        A list of the two tensors returned by ``serializer.addOutput``.
    """
    shape_in = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(shape_in) == 3

    out_shape = list(shape_in[:-1])
    out_shape.append(shape_in[-1] // 2 + 1)
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Two outputs with identical shape and dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]