blob: 65bdeb737f581c63ab6750ccf028927d8b1dc0c8 [file] [log] [blame]
Jerry Ge9e94af82022-10-27 09:57:00 -07001# Copyright (c) 2020-2023, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01004from copy import deepcopy
Eric Kunzee5e26762020-10-13 16:11:07 -07005
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +00007import serializer.tosa_serializer as ts
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_arg_gen import TosaArgGen
9from generator.tosa_arg_gen import TosaQuantGen
10from generator.tosa_arg_gen import TosaTensorGen
11from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from generator.tosa_error_if import TosaErrorIfArgGen
14from generator.tosa_error_if import TosaErrorValidator
15from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +010016from generator.tosa_utils import DTYPE_ATTRIBUTES
Luke Huttona4e48ca2023-02-22 11:53:48 +000017from generator.tosa_utils import get_rank_mismatch_shape
Jeremy Johnson05c711e2022-12-12 18:00:41 +000018from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010019from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_utils import usableDTypes
James Ward24dbc422022-10-19 12:20:31 +010021from generator.tosa_utils import vect_f32_to_bf16
Les Bell0e027d42021-11-09 14:42:14 +000022from tosa.DType import DType
23from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010024
25
Eric Kunzee5e26762020-10-13 16:11:07 -070026class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010027 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000028 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010029 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010030 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010031 TOSA_8K_LEVEL_MAX_KERNEL = 8192
32 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010033
Eric Kunzee5e26762020-10-13 16:11:07 -070034 def __init__(self, args):
35 self.args = args
36 self.basePath = args.output_dir
37 self.random_seed = args.random_seed
38 self.ser = None
39 self.rng = np.random.default_rng(self.random_seed)
40 self.createDynamicOpLists()
41 self.initOpListDefaults()
42 self.quantGen = TosaQuantGen()
43 # Force makeShape to do a specific starting shape
44 self.targetted_shape = None
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010045 # Work out floating point range
46 self.random_fp_low = min(args.tensor_fp_value_range)
47 self.random_fp_high = max(args.tensor_fp_value_range)
Eric Kunzee5e26762020-10-13 16:11:07 -070048
49 def createSerializer(self, opName, testPath):
50 self.testPath = os.path.join(opName, testPath)
51
52 fullPath = os.path.join(self.basePath, self.testPath)
53 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnsona0848c62022-09-15 15:01:30 +010054 self.ser = ts.TosaSerializer(fullPath, saveConstsToFile=self.args.dump_consts)
Eric Kunzee5e26762020-10-13 16:11:07 -070055
56 def getSerializer(self):
57 return self.ser
58
59 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -080060 with open(
61 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
62 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -070063 fd.write(self.ser.serialize())
64
Kevin Cheng550ccc52021-03-03 11:21:43 -080065 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
66 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -070067
Matthew Haddon74567092021-07-16 15:38:20 +010068 def resetRNG(self, seed=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000069 if seed is None:
Matthew Haddon74567092021-07-16 15:38:20 +010070 seed = self.random_seed + 1
71 self.rng = np.random.default_rng(seed)
72
Eric Kunzee5e26762020-10-13 16:11:07 -070073 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -070074 if dtype == DType.BOOL:
Eric Kunzee5e26762020-10-13 16:11:07 -070075 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -070076 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -070077 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -070078 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070079 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +010080 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
81 elif dtype == DType.UINT8:
82 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070083 elif dtype == DType.INT16:
84 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +010085 elif dtype == DType.UINT16:
86 return np.int32(self.rng.integers(low=0, high=65536, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -070087 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -080088 return np.int32(
89 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
90 )
Eric Kunzee5e26762020-10-13 16:11:07 -070091 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -080092 return np.int64(
93 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
94 )
James Ward8b390432022-08-12 20:48:56 +010095 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +010096 return np.float16(
97 self.rng.uniform(
98 low=self.random_fp_low, high=self.random_fp_high, size=shape
99 )
100 )
James Ward24dbc422022-10-19 12:20:31 +0100101 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100102 f32_tensor = np.float32(
103 self.rng.uniform(
104 low=self.random_fp_low, high=self.random_fp_high, size=shape
105 )
106 )
James Ward24dbc422022-10-19 12:20:31 +0100107 # Floor the last 16 bits of each f32 value
108 return np.float32(vect_f32_to_bf16(f32_tensor))
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100109 elif dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100110 return np.float32(
111 self.rng.uniform(
112 low=self.random_fp_low, high=self.random_fp_high, size=shape
113 )
114 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700115 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800116 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700117
Kevin Cheng989cb052021-04-28 16:29:44 -0700118 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700119 placeholders = []
120
Kevin Cheng989cb052021-04-28 16:29:44 -0700121 assert len(shape_list) == len(dtype_list)
122
123 for idx, shape in enumerate(shape_list):
124 arr = self.getRandTensor(shape, dtype_list[idx])
125 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700126
127 return placeholders
128
Kevin Cheng989cb052021-04-28 16:29:44 -0700129 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700130 consts = []
131
Kevin Cheng989cb052021-04-28 16:29:44 -0700132 assert len(shape_list) == len(dtype_list)
133
134 for idx, shape in enumerate(shape_list):
135 arr = self.getRandTensor(shape, dtype_list[idx])
136 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700137
138 return consts
139
140 def makeShape(self, rank):
141 if self.targetted_shape:
142 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800143 return np.int32(
144 self.rng.integers(
145 low=self.args.tensor_shape_range[0],
146 high=self.args.tensor_shape_range[1],
147 size=rank,
148 )
149 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700150
151 def setTargetShape(self, shape):
152 self.targetted_shape = shape
153
154 def randInt(self, low=0, high=256):
155 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
156
157 def getRandNumberDType(self, dtype):
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100158 if dtype == DType.FP32:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100159 return np.float32(
160 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
161 )
James Ward8b390432022-08-12 20:48:56 +0100162 elif dtype == DType.FP16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100163 return np.float16(
164 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
165 )
James Ward24dbc422022-10-19 12:20:31 +0100166 elif dtype == DType.BF16:
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100167 rand_f32 = np.float32(
168 self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
169 )
James Ward24dbc422022-10-19 12:20:31 +0100170 return vect_f32_to_bf16(rand_f32)
Eric Kunzee5e26762020-10-13 16:11:07 -0700171 elif dtype == DType.BOOL:
172 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700173 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700175 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700176 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100177 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 elif dtype == DType.INT16:
179 low, high = (-32768, 32768)
180 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800181 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -0700182 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800183 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -0700184 # Special size
185 return np.int64(self.rng.integers(low, high, size=1))[0]
186 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800187 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700188
189 return np.int32(self.rng.integers(low, high, size=1))[0]
190
191 def shapeStr(self, shape):
192
193 sStr = []
194 # Convert to strings
195 for i in shape:
196 sStr.append(str(i))
197
Kevin Cheng550ccc52021-03-03 11:21:43 -0800198 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700199
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100200 def typeStr(self, dtype):
201 if isinstance(dtype, list) or isinstance(dtype, tuple):
202 assert len(dtype) >= 2
203 strs = [self.typeStr(t) for t in dtype]
204 # Limit types to the first 2 as the 3rd is the accumulator
205 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100207 if dtype in DTYPE_ATTRIBUTES:
208 return DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700209 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100210 raise Exception(
211 "Unknown dtype, cannot convert to string: {}".format(dtype)
212 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700213
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100214 def typeWidth(self, dtype):
James Ward8b390432022-08-12 20:48:56 +0100215 """Get the datatype width for data types"""
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100216 if dtype in DTYPE_ATTRIBUTES:
217 return DTYPE_ATTRIBUTES[dtype]["width"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700218 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100219 raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
Eric Kunzee5e26762020-10-13 16:11:07 -0700220
Luke Hutton57287132023-02-06 14:54:18 +0000221 def constrictBatchSize(self, shape):
222 # Limit the batch size unless an explicit target shape set
223 if self.args.max_batch_size and not self.args.target_shapes:
224 shape[0] = min(shape[0], self.args.max_batch_size)
225 return shape
226
James Ward30124a82023-02-02 14:56:33 +0000227 def makeDimension(self):
228 return self.randInt(
229 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
230 )
231
    # Operator build functions
    # The argument generators return a list of tuples
    # (stringDescriptor, [build_fcn_arg_list]), where the string descriptor
    # is used to generate the test name, and the build_fcn_arg_list is
    # expanded and passed to one of the build_* functions below.

Matthew Haddone4ecdb22021-09-28 11:38:21 +0100238 def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
239 result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
240
Matthew Haddon848efb42021-09-09 12:30:53 +0100241 # build_placeholder returns an int, ABS/other ops does not
242 if isinstance(op, int):
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000243 self.ser.addOperator(op, a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100244 return result_tens
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000245 elif op["op"] == Op.IDENTITY:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000246 self.ser.addOperator(op["op"], a.name, result_tens.name, None)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100247 return result_tens
248
249 # Ensure new output type has correct qinfo
250 if error_name == ErrorIf.WrongOutputType:
251 if result_tens.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000252 qinfo = [
253 TosaQuantGen.getZeroPoint(self, a.dtype),
254 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
255 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100256
257 # Invalidate Input/Output list for error if checks.
258 input_list = [a.name]
259 output_list = [result_tens.name]
260 pCount, cCount = op["operands"]
261 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000262 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
263 self, error_name, input_list, output_list
264 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100265
Les Bell729b0352021-11-24 10:28:21 +0000266 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100267 self.ser,
268 validator_fcns,
269 error_name,
270 op=op,
271 input_dtype=a.dtype,
272 output_dtype=result_tens.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000273 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000274 result_tensors=[result_tens],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100275 input_list=input_list,
276 output_list=output_list,
277 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000278 ):
279 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100280
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000281 attr = None
282 if op["op"] == Op.NEGATE:
283 attr = ts.TosaSerializerAttribute()
284 attr.NegateAttribute(qinfo[0], qinfo[1])
285
286 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 return result_tens
288
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100289 def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000290 result_tens = OutputShaper.binaryBroadcastOp(
291 self.ser, self.rng, a, b, error_name
292 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100293
294 # Invalidate Input/Output list for error if checks.
295 input_list = [a.name, b.name]
296 output_list = [result_tens.name]
297 pCount, cCount = op["operands"]
298 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000299 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
300 self, error_name, input_list, output_list
301 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100302
Les Bell729b0352021-11-24 10:28:21 +0000303 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100304 self.ser,
305 validator_fcns,
306 error_name,
307 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000308 input1=a,
309 input2=b,
310 input_dtype=a.dtype,
311 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000312 result_tensors=[result_tens],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100313 input_list=input_list,
314 output_list=output_list,
315 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000316 ):
317 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100318
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000319 self.ser.addOperator(op["op"], input_list, output_list)
Eric Kunzee5e26762020-10-13 16:11:07 -0700320 return result_tens
321
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100322 def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700323 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000324 self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -0700325 return result_tens
326
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000327 def build_arithmetic_right_shift(
328 self, op, a, b, round, validator_fcns=None, error_name=None
329 ):
330 result_tens = OutputShaper.binaryBroadcastOp(
331 self.ser, self.rng, a, b, error_name
332 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100333
334 # Invalidate Input/Output list for error if checks.
335 input_list = [a.name, b.name]
336 output_list = [result_tens.name]
337 pCount, cCount = op["operands"]
338 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000339 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
340 self, error_name, input_list, output_list
341 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100342
Les Bell729b0352021-11-24 10:28:21 +0000343 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100344 self.ser,
345 validator_fcns,
346 error_name,
347 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000348 input1=a,
349 input2=b,
350 input_dtype=a.dtype,
351 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000352 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100353 input_list=input_list,
354 output_list=output_list,
355 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000356 ):
357 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800358
359 attr = ts.TosaSerializerAttribute()
360 attr.ArithmeticRightShiftAttribute(round)
361
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000362 self.ser.addOperator(op["op"], input_list, output_list, attr)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800363 return result_tens
364
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100365 def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000366 result_tens = OutputShaper.binaryBroadcastOp(
367 self.ser, self.rng, a, b, error_name
368 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700369
370 # Special for multiply:
371 # Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100372 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Eric Kunzee5e26762020-10-13 16:11:07 -0700373 result_tens.setDtype(DType.INT32)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100374 if error_name == ErrorIf.WrongOutputType:
375 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
376 outputDType = self.rng.choice(all_dtypes)
377 result_tens.setDtype(outputDType)
378
379 # Invalidate Input/Output list for error if checks.
380 input_list = [a.name, b.name]
381 output_list = [result_tens.name]
382 pCount, cCount = op["operands"]
383 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
385 self, error_name, input_list, output_list
386 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100387
Les Bell729b0352021-11-24 10:28:21 +0000388 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100389 self.ser,
390 validator_fcns,
391 error_name,
392 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 input1=a,
394 input2=b,
395 input_dtype=a.dtype,
396 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000397 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100398 input_list=input_list,
399 output_list=output_list,
400 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000401 ):
402 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700403
Kevin Chengaee1fac2020-11-11 13:54:06 -0800404 attr = ts.TosaSerializerAttribute()
405 attr.MulAttribute(shift)
406
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000407 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700408 return result_tens
409
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100410 def build_table(self, op, a, table, validator_fcns=None, error_name=None):
411 result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700412
Kevin Chengfe392ce2021-10-18 21:51:55 +0000413 attr = ts.TosaSerializerAttribute()
414 attr.TableAttribute(table)
415
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100416 # Invalidate Input/Output list for error if checks.
417 input_list = [a.name]
418 output_list = [result_tens.name]
419 pCount, cCount = op["operands"]
420 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000421 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
422 self, error_name, input_list, output_list
423 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100424
Les Bell729b0352021-11-24 10:28:21 +0000425 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100426 self.ser,
427 validator_fcns,
428 error_name,
429 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000430 input_shape=a.shape,
431 input_dtype=a.dtype,
432 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000433 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100434 input_list=input_list,
435 output_list=output_list,
436 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000437 ):
438 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100439
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000440 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700441
442 return result_tens
443
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100444 def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
445 result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
446
447 # Invalidate Input/Output list for error if checks.
448 input_list = [cond.name, a.name, b.name]
449 output_list = [result_tens.name]
450 pCount, cCount = op["operands"]
451 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000452 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
453 self, error_name, input_list, output_list
454 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100455
Les Bell729b0352021-11-24 10:28:21 +0000456 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100457 self.ser,
458 validator_fcns,
459 error_name,
460 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000461 input1=cond,
462 input2=a,
463 input3=b,
464 input_shape=a.shape,
465 input_dtype=a.dtype,
466 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000467 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100468 input_list=input_list,
469 output_list=output_list,
470 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000471 ):
472 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100473
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 self.ser.addOperator(
475 op["op"],
476 input_list,
477 output_list,
478 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 return result_tens
480
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100481 def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000482 result_tens = OutputShaper.binaryComparisonOp(
483 self.ser, self.rng, a, b, error_name
484 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100485
486 # Invalidate Input/Output list for error if checks.
487 input_list = [a.name, b.name]
488 output_list = [result_tens.name]
489 pCount, cCount = op["operands"]
490 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000491 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
492 self, error_name, input_list, output_list
493 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100494
Les Bell729b0352021-11-24 10:28:21 +0000495 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100496 self.ser,
497 validator_fcns,
498 error_name,
499 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000500 input1=a,
501 input2=b,
502 input_shape=a.shape,
503 input_dtype=a.dtype,
504 output_shape=result_tens.shape,
505 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000506 result_tensors=[result_tens],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100507 input_list=input_list,
508 output_list=output_list,
509 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000510 ):
511 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100512
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000513 self.ser.addOperator(
514 op["op"],
515 input_list,
516 output_list,
517 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700518 return result_tens
519
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100520 def build_argmax(self, op, a, axis, validator_fcns, error_name):
521 result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
522
523 # Invalidate Input/Output list for error if checks.
524 input_list = [a.name]
525 output_list = [result_tens.name]
526 pCount, cCount = op["operands"]
527 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000528 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
529 self, error_name, input_list, output_list
530 )
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100531
Les Bell729b0352021-11-24 10:28:21 +0000532 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100533 self.ser,
534 validator_fcns,
535 error_name,
536 op=op,
537 axis=axis,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000538 input_shape=a.shape,
539 input_dtype=a.dtype,
540 output_shape=result_tens.shape,
541 output_dtype=result_tens.dtype,
Luke Hutton261b7b62023-01-10 14:50:31 +0000542 result_tensors=[result_tens],
Matthew Haddonc4cf0372021-10-11 09:38:10 +0100543 input_list=input_list,
544 output_list=output_list,
545 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000546 ):
547 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700548
549 attr = ts.TosaSerializerAttribute()
550 attr.AxisAttribute(axis)
551
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000552 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553 return result_tens
554
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000555 def build_pool2d(
556 self,
557 op,
558 input,
James Ward8b390432022-08-12 20:48:56 +0100559 accum_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000560 stride,
561 pad,
562 kernel,
563 validator_fcns=None,
564 error_name=None,
565 qinfo=None,
566 ):
567 result_tens = OutputShaper.pool2dOp(
568 self.ser, self.rng, input, kernel, stride, pad, error_name
569 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100570
571 # Ensure new output type has correct qinfo
572 if error_name == ErrorIf.WrongInputType:
573 if input.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000574 qinfo = [
575 TosaQuantGen.getZeroPoint(self, input.dtype),
576 TosaQuantGen.getZeroPoint(self, result_tens.dtype),
577 ]
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100578
579 # Invalidate Input/Output list for error if checks.
580 input_list = [input.name]
581 output_list = [result_tens.name]
582 pCount, cCount = op["operands"]
583 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000584 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
585 self, error_name, input_list, output_list
586 )
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100587
Les Bell729b0352021-11-24 10:28:21 +0000588 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100589 self.ser,
590 validator_fcns,
591 error_name,
592 op=op,
593 input_shape=input.shape,
594 input_dtype=input.dtype,
595 output_shape=result_tens.shape,
596 output_dtype=result_tens.dtype,
597 kernel=kernel,
598 stride=stride,
599 pad=pad,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000600 qinfo=qinfo,
Luke Hutton261b7b62023-01-10 14:50:31 +0000601 result_tensors=[result_tens],
Matthew Haddonb6b59e32021-10-07 17:19:20 +0100602 input_list=input_list,
603 output_list=output_list,
604 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000605 ):
606 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700607
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000608 if qinfo is None:
609 qinfo = [0, 0]
Eric Kunzee5e26762020-10-13 16:11:07 -0700610
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000611 attr = ts.TosaSerializerAttribute()
James Ward8b390432022-08-12 20:48:56 +0100612 attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000613
614 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700615 return result_tens
616
James Ward8b390432022-08-12 20:48:56 +0100617 def build_maxpool2d(
618 self,
619 op,
620 input,
621 stride,
622 pad,
623 kernel,
624 validator_fcns=None,
625 error_name=None,
626 qinfo=None,
627 ):
628 # Same as build_pool2d but manually sets accum_dtype value
629 # (maxpool has no accum_dtype)
630 return self.build_pool2d(
631 op,
632 input,
633 DType.UNKNOWN,
634 stride,
635 pad,
636 kernel,
637 validator_fcns,
638 error_name,
639 qinfo,
640 )
641
def build_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV2D operator and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator type forwarded to ``OutputShaper.conv2dOp``.
        strides: Stride values.
        padding: Padding values; exactly 4 entries are required.
        dilations: Dilation values.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Zero points indexed as ``[input, output]``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(padding) == 4
    result = OutputShaper.conv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # With a deliberately wrong input type that is not INT8/UINT8, regenerate
    # the zero points from the input and output types actually produced.
    needs_new_zp = error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    )
    if needs_new_zp:
        input_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(self, result.dtype)
        qinfo = [input_zp, output_zp]

    operand_count = sum(op["operands"])
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=operand_count,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result.shape,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result
713
def build_conv3d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator type forwarded to ``OutputShaper.conv3dOp``.
        strides: Stride values.
        padding: Padding values; exactly 6 entries are required.
        dilations: Dilation values.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Zero points indexed as ``[input, output]``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # A wrong, non-8-bit input type needs zero points that match the tensor
    # types actually produced.
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in (DType.INT8, DType.UINT8):
            qinfo = [
                TosaQuantGen.getZeroPoint(self, ifm.dtype),
                TosaQuantGen.getZeroPoint(self, out_tensor.dtype),
            ]

    total_operands = sum(op["operands"])
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_shape=filter.shape,
        weight_dtype=filter.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_list=inputs,
        output_list=outputs,
        num_operands=total_operands,
    ):
        return None

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], inputs, outputs, conv_attr)
    return out_tensor
785
def build_transpose_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    stride,
    out_pad,
    output_shape,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator type forwarded to
            ``OutputShaper.transposeConv2DOp``.
        stride: Stride values.
        out_pad: Output padding values; exactly 4 entries are required.
        output_shape: Requested output shape, used both for shaping the
            result tensor and for the serialized attribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Zero points indexed as ``[input, output]``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    assert len(out_pad) == 4
    result = OutputShaper.transposeConv2DOp(
        self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
    )

    # Regenerate zero points when the input type is wrong on purpose and is
    # not one of the 8-bit integer types.
    wrong_type = error_name == ErrorIf.WrongInputType
    if wrong_type and ifm.dtype not in (DType.INT8, DType.UINT8):
        in_zp = TosaQuantGen.getZeroPoint(self, ifm.dtype)
        out_zp = TosaQuantGen.getZeroPoint(self, result.dtype)
        qinfo = [in_zp, out_zp]

    operand_count = sum(op["operands"])
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=operand_count,
        output_list=output_list,
        pad=out_pad,
        stride=stride,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result.shape,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result
848
def build_depthwise_conv2d(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    strides,
    padding,
    dilations,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            the entries of ``op["operands"]`` are summed for the operand count.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator type forwarded to
            ``OutputShaper.depthwiseConv2dOp``.
        strides: Stride values.
        padding: Padding values; exactly 4 entries are required, as for
            ``build_conv2d``.
        dilations: Dilation values.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Zero points indexed as ``[input, output]``.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    # Same 2-D padding invariant that build_conv2d enforces.
    assert len(padding) == 4
    result_tens = OutputShaper.depthwiseConv2dOp(
        self.ser,
        self.rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo: with a wrong, non-8-bit input
    # type the zero points are regenerated from the actual tensor types.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(self, ifm.dtype),
            TosaQuantGen.getZeroPoint(self, result_tens.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tens.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tens.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tens.shape,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
919
def build_fully_connected(
    self,
    op,
    ifm,
    filter,
    bias,
    accum_dtype,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        ifm: Input tensor.
        filter: Weight tensor.
        bias: Bias tensor.
        accum_dtype: Accumulator type, forwarded to the output shaper and to
            ERROR_IF validation.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Indexable pair; entries 0 and 1 go to
            ``FullyConnectedAttribute`` and the whole value to validation.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result = OutputShaper.fullyConnectedOp(
        self.ser, self.rng, ifm, filter, accum_dtype, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [ifm.name, filter.name, bias.name], [result.name]
    )

    ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        qinfo=qinfo,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        accum_dtype=accum_dtype,
    )
    if not ok:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result
968
def build_matmul(
    self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
):
    """Serialize a MATMUL operator on tensors ``a`` and ``b``.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: First input tensor.
        b: Second input tensor.
        accum_dtype: Accumulator type, forwarded to the output shaper and to
            ERROR_IF validation.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Indexable pair; entries 0 and 1 go to ``MatMulAttribute`` and
            the whole value to validation.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    product = OutputShaper.matmulOp(self.ser, self.rng, a, b, accum_dtype, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name, b.name], [product.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=product.shape,
        output_dtype=product.dtype,
        qinfo=qinfo,
        result_tensors=[product],
        input_list=names_in,
        output_list=names_out,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    ):
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], names_in, names_out, matmul_attr)
    return product
1010
def build_reduce(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a reduction operator over ``axis`` and return its output.

    ``validator_fcns`` now defaults to None, matching the other ``build_*``
    methods. Callers that pass it positionally are unaffected.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        axis: Axis to reduce, stored in an ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1045
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a CLAMP operator with randomly drawn bounds.

    Two random values of ``a.dtype`` are drawn. Normally the smaller one is
    the minimum and the larger one the maximum. For ``ErrorIf.MaxSmallerMin``
    the pair is redrawn until the values differ, then assigned the other way
    round so that max < min.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    def draw_pair():
        return [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

    pair = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        min_val, max_val = min(pair), max(pair)
    else:
        # Equal values could never give max < min, so keep drawing.
        while pair[0] == pair[1]:
            pair = draw_pair()
        min_val, max_val = max(pair), min(pair)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    if a.dtype not in (DType.BF16, DType.FP16, DType.FP32):
        # Non-float types pass the bounds in the first two value slots.
        attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
    else:
        if a.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        # Float types pass the bounds in the last two value slots.
        attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result
1101
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a LEAKY_RELU operator with a random FP32 attribute value.

    No ERROR_IF validation is done here. ``validator_fcns`` is accepted but
    unused, and ``error_name`` is only forwarded to the output shaper.

    Returns:
        The output tensor.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    leaky_attr = ts.TosaSerializerAttribute()
    alpha = self.getRandNumberDType(DType.FP32)
    leaky_attr.LeakyReluAttribute(alpha)

    self.ser.addOperator(op["op"], [a.name], [result.name], leaky_attr)
    return result
1110
# NOTE: PRELU needs an additional type/input; for now it is built like a
# unary operator.
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a PRELU operator with a single input tensor.

    No ERROR_IF validation is done here. ``validator_fcns`` is accepted but
    unused, and ``error_name`` is only forwarded to the output shaper.

    Returns:
        The output tensor.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
    self.ser.addOperator(op["op"], [a.name], [out.name])
    return out
1117
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a SIGMOID operator (no attribute) and return its output.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    inputs, outputs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result.name]
    )

    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result.shape,
        input_dtype=a.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=inputs,
        output_list=outputs,
        num_operands=p_count + c_count,
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], inputs, outputs)
    return result
1148
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
    """Serialize a TANH operator (no attribute) and return its output.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out_tensor.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return out_tensor
1179
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
    """Serialize a CONCAT operator and return its output tensor.

    The positional arguments are the input tensors followed by the
    concatenation axis as the final element. Outside the
    ``ErrorIf.WrongInputType`` case, that final element must be an ``int``.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        *a: Input tensors, then the axis.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    if error_name != ErrorIf.WrongInputType:
        assert type(a[-1]) is int

    # The axis travels with the variable-length tensor list as its last item.
    tensors, axis = a[:-1], a[-1]

    result = OutputShaper.concatOp(
        self.ser, self.rng, axis, *tensors, error_name=error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [t.name for t in tensors], [result.name]
    )

    first = tensors[0]
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=result.shape,
        input_dtype=first.dtype,
        output_dtype=result.dtype,
        inputs=tensors,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return result
Eric Kunzee5e26762020-10-13 16:11:07 -07001228
def build_pad(
    self,
    op,
    a,
    padding,
    pad_const_int,
    pad_const_float,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        padding: Padding amounts. They are flattened via ``padding.flatten()``
            for the attribute, so this is presumably a numpy array (confirm
            with callers).
        pad_const_int: Integer pad constant stored in the attribute.
        pad_const_float: Floating-point pad constant stored in the attribute.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Forwarded to ERROR_IF validation only.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=result_tens.shape,
        input_dtype=a.dtype,
        output_dtype=result_tens.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tens],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=a,
    ):
        return None

    # Build the attribute only once the test is known to be kept.
    # PadAttribute is handed the serializer's shared builder, so creating it
    # before validation would hand the builder data for a test that is
    # then rejected. The other builders also create attributes after
    # validation.
    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(
        self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001277
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
    """Serialize a RESHAPE of ``a`` to ``newShape`` and return its output.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        newShape: Target shape, used for the output tensor and stored in a
            ``ReshapeAttribute``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    reshaped = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [reshaped.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=reshaped.shape,
        output_dtype=reshaped.dtype,
        result_tensors=[reshaped],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(newShape)
    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return reshaped
1313
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
    """Serialize a REVERSE of ``a`` along ``axis`` and return its output.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        axis: Axis stored in an ``AxisAttribute``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    out = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

    p_count, c_count = op["operands"]
    total = p_count + c_count
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    ins, outs = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [out.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=out.shape,
        output_dtype=out.dtype,
        result_tensors=[out],
        input_list=ins,
        output_list=outs,
        num_operands=total,
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], ins, outs, axis_attr)
    return out
1348
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
    """Serialize a TRANSPOSE of ``a`` by ``perms`` and return its output.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        perms: Permutation, used for the output tensor, ERROR_IF validation
            and the ``TransposeAttribute``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    result = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [result.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        perms=perms,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, transpose_attr)
    return result
1384
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
    """Serialize a SLICE of ``a`` and return its output tensor.

    Args:
        op: Operator description; ``op["op"]`` is handed to the serializer and
            ``op["operands"]`` is unpacked as two operand counts.
        a: Input tensor.
        start: Slice start, stored in a ``SliceAttribute``.
        size: Slice size, stored in a ``SliceAttribute``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.

    Returns:
        The output tensor, or None when ERROR_IF validation rejects the test.
    """
    sliced = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be invalidated here.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [sliced.name]
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=sliced.shape,
        output_dtype=sliced.dtype,
        start=start,
        size=size,
        result_tensors=[sliced],
        input_list=names_in,
        output_list=names_out,
        num_operands=p_count + c_count,
    ):
        return None

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(start, size)
    self.ser.addOperator(op["op"], names_in, names_out, slice_attr)
    return sliced
1423
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
    """Serialize a TILE operator replicating ``a`` according to ``multiples``.

    ``multiples`` is passed to ``OutputShaper.tileOp`` and stored in the
    operator's TileAttribute.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    tiled = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

    # Operand name lists may be altered by the ERROR_IF helper.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [a.name], [tiled.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        output_shape=tiled.shape,
        input_dtype=a.dtype,
        output_dtype=tiled.dtype,
        result_tensors=[tiled],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=a,
    )
    if not is_valid:
        return None

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)
    self.ser.addOperator(op["op"], in_names, out_names, tile_attr)
    return tiled
1458
def build_gather(self, op, values, validator_fcns=None, error_name=None):
    """Serialize a GATHER from ``values`` using freshly generated indices.

    A constant INT32 index tensor of shape ``[values.shape[0], W]`` is added
    here. ``W`` is drawn with ``self.randInt`` from
    ``self.args.tensor_shape_range``, and every index lies in
    ``[0, values.shape[1])``, so no index exceeds the dimensions of ``values``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    k_extent = values.shape[1]
    low, high = self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    w_extent = self.randInt(low, high)
    index_values = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values.shape[0], w_extent])
    )
    index_tens = self.ser.addConst(index_values.shape, DType.INT32, index_values)

    gathered = OutputShaper.gatherOp(
        self.ser, self.rng, values, index_tens, error_name
    )

    # Operand name lists may be altered by the ERROR_IF helper.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [values.name, index_tens.name], [gathered.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=gathered.shape,
        input_dtype=values.dtype,
        output_dtype=gathered.dtype,
        result_tensors=[gathered],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return gathered
1505
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
    """Serialize a SCATTER of ``input`` into ``values_in``.

    A constant INT32 index tensor of shape
    ``[values_in.shape[0], input.shape[1]]`` is added here, with every index
    in ``[0, values_in.shape[1])`` so no index exceeds the dimensions of
    ``values_in``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    k_extent = values_in.shape[1]
    w_extent = input.shape[1]
    index_values = np.int32(
        self.rng.integers(low=0, high=k_extent, size=[values_in.shape[0], w_extent])
    )
    index_tens = self.ser.addConst(index_values.shape, DType.INT32, index_values)

    scattered = OutputShaper.scatterOp(
        self.ser, self.rng, values_in, index_tens, input, error_name
    )

    # Operand name lists may be altered by the ERROR_IF helper.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [values_in.name, index_tens.name, input.name],
        [scattered.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=scattered.shape,
        input_dtype=values_in.dtype,
        output_dtype=scattered.dtype,
        result_tensors=[scattered],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return scattered
1550
def build_resize(
    self,
    op,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    validator_fcns,
    error_name=None,
):
    """Serialize a RESIZE of ``input``.

    ``mode``, ``scale``, ``offset`` and ``border`` are passed to
    ``OutputShaper.resizeOp``, to the ERROR_IF validators and to the
    operator's ResizeAttribute. The dtypes given to the validators are the
    ``input_dtype``/``output_dtype`` arguments, not the tensors' own dtypes.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    resized = OutputShaper.resizeOp(
        self.ser,
        self.rng,
        input,
        mode,
        scale,
        offset,
        border,
        input_dtype,
        output_dtype,
        error_name,
    )

    # Operand name lists may be altered by the ERROR_IF helper.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [input.name], [resized.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        offset=offset,
        border=border,
        input_dtype=input_dtype,
        output_dtype=output_dtype,
        input_shape=input.shape,
        output_shape=resized.shape,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[resized],
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(scale, offset, border, mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)
    return resized
1612
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
    """Serialize an IDENTITYN operator passing ``val`` and ``val2`` through.

    Two output tensors are created with ``OutputShaper.unaryOp``; only the
    first one is returned. ``validator_fcns`` is accepted for signature
    compatibility but unused: no ERROR_IF validation happens here.
    """
    result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
    result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
    # The sibling builders receive the op-list entry (a dict) and serialize
    # its "op" field; passing the whole dict to addOperator is wrong. A bare
    # opcode is still accepted for backward compatibility.
    op_code = op["op"] if isinstance(op, dict) else op
    self.ser.addOperator(
        op_code, [val.name, val2.name], [result_tens.name, result_tens2.name]
    )
    return result_tens
1620
def build_const(self, op, val, validator_fcns=None, error_name=None):
    """Mark ``val`` as an output tensor of the graph and return it unchanged.

    ``op``, ``validator_fcns`` and ``error_name`` are accepted only so the
    signature matches the other builders; they are not used.
    """
    serializer = self.ser
    serializer.addOutputTensor(val)
    return val
Eric Kunzee5e26762020-10-13 16:11:07 -07001624
1625 # Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
    """Serialize a CAST of ``val`` to ``out_dtype``.

    The output tensor is created by ``OutputShaper.typeConversionOp``.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    converted = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    # Operand name lists may be altered by the ERROR_IF helper.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [converted.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        output_shape=converted.shape,
        input_dtype=val.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return converted
1658
def build_rescale(
    self,
    op,
    val,
    out_dtype,
    scale32,
    double_round,
    per_channel,
    validator_fcns,
    error_name,
):
    """Serialize a RESCALE of ``val`` to ``out_dtype`` with random parameters.

    - Zero points: chosen from the input/output dtypes. For the zero-point
      ERROR_IF tests they are forced non-zero.
    - Scale: a random scale is drawn per channel (``val.shape[-1]`` channels
      when ``per_channel``, otherwise one) and converted to
      multiplier/shift pairs.
    - Input data: for valid scale32 tests, the input placeholder data on disk
      is clamped to satisfy the apply_scale_32 requirement on the shift.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    result_tens = OutputShaper.typeConversionOp(
        self.ser, self.rng, val, out_dtype, error_name
    )

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    # Each branch that picks a zero point also widens the type width by one
    # bit (presumably to allow for the zero-point offset - confirm).
    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 128)
        in_type_width += 1
    elif val.dtype == DType.UINT8:
        input_zp = self.randInt(0, 256)
        in_type_width += 1
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = self.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + self.rng.integers(1, 10)
        in_type_width += 1
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = self.rng.choice([0, 32768])
        in_type_width += 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 128)
        out_type_width += 1
    elif out_dtype == DType.UINT8:
        output_zp = self.randInt(0, 256)
        out_type_width += 1
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = self.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + self.rng.integers(1, 10)
        out_type_width += 1
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = self.rng.choice([0, 32768])
        out_type_width += 1
    else:
        output_zp = 0

    # Calculate scale based on:
    # scale = a * (2^output_width) / (2^input_width)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        # Do the bound arithmetic on Python ints. shift_arr holds int32
        # scalars, and under NumPy 2 (NEP 50) promotion "1 << shift" would
        # stay int32 and overflow for the shifts >= 32 that scale32 produces.
        shift_minus_one = int(shift_arr[i]) - 1
        min_shift_value_arr[i] = -(1 << shift_minus_one)
        max_shift_value_arr[i] = (1 << shift_minus_one) - 1

    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values = np.load(
            os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        )
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=values.dtype)
        if not np.all(np.array_equal(values, val_adj)):
            # Values changed so overwrite file with new values
            np.save(
                os.path.join(self.basePath, self.testPath, val.placeholderFilename),
                val_adj,
                False,
            )

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name]
    output_list = [result_tens.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tens],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)
    return result_tens
1813
def _get_condition_tensor(self, op, cond, error_name):
    """Add and return the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 (single element) BOOL constant holding
    ``cond``. The ERROR_IF tests change it as follows:

    - CondIfCondNotMatchingBool: the dtype comes from
      ``get_wrong_output_type``.
    - CondIfCondShapeNotSizeOne: the shape is randomly ``[2]`` or ``[1, 2]``.
    """
    if error_name == ErrorIf.CondIfCondNotMatchingBool:
        tens_type = get_wrong_output_type(op, self.rng, DType.BOOL)
    else:
        tens_type = DType.BOOL

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        tens_shape = []
    elif self.rng.choice([1, 2]) == 1:
        tens_shape = [2]
    else:
        tens_shape = [1, 2]

    return self.ser.addConst(tens_shape, tens_type, [cond])
1830
def build_cond_if_const(
    self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
    """Serialize a COND_IF whose THEN/ELSE blocks each emit one constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored. Both
    blocks produce random INT32 constants of that shape. For the
    CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the named block
    emits a constant with a perturbed shape instead.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    out_shape = then_tens.shape

    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        # Perturb every dimension; dimensions of 3 or less are only grown.
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = self.rng.choice([-3, -2, 2, 3])
            else:
                delta = self.rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)

    # Each block holds a single constant that is also the block's output.
    block_plan = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, mismatch_error in block_plan:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_out = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    return result_tens if is_valid else None
1899
def build_cond_if_binary(
    self, op, a, b, cond, validator_fcns=None, error_name=None
):
    """Serialize a COND_IF whose THEN/ELSE blocks apply a binary op to a and b.

    The block ops depend on the dtype of ``a``:

    - FP32/BF16/FP16/INT32: THEN adds and ELSE subtracts.
    - INT8/INT16: THEN uses LOGICAL_RIGHT_SHIFT and ELSE uses
      LOGICAL_LEFT_SHIFT.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the named block gets an input or output with a perturbed shape.

    Returns the output tensor, or None when the ERROR_IF validators reject
    the generated test.
    """
    # Condition tensor
    cond_tens = self._get_condition_tensor(op, cond, error_name)

    result_tens = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = Op.ADD, Op.SUB
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # The loop variable must not be called "op": that would shadow the
    # parameter and hand the block opcode, rather than the op-list entry, to
    # the ERROR_IF validators below.
    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    return result_tens
1982
Matthew Haddon630c17c2021-10-14 15:05:41 +01001983 def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001984 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001985
Kevin Cheng550ccc52021-03-03 11:21:43 -08001986 cond_block = "COND_BLOCK"
1987 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001988
1989 attr = ts.TosaSerializerAttribute()
1990 attr.WhileLoopAttribute(cond_block, body_block)
1991
1992 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001993 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001994 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001995 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001996
1997 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001998 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1999 a_out = self.ser.addIntermediate(a.shape, a.dtype)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002000 if error_name == ErrorIf.InputListOutputListMismatch:
2001 incorrect_acc = deepcopy(acc)
2002 for i in range(len(incorrect_acc.shape)):
2003 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2004 acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
2005 else:
2006 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002007
2008 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08002009 self.ser.addOperator(
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002010 op["op"],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002011 [iter.name, a.name, acc.name],
2012 [iter_out.name, a_out.name, acc_out.name],
2013 attr,
2014 )
Kevin Chengb227ae52021-09-02 13:43:17 -07002015 self.ser.addOutputTensor(acc_out)
Eric Kunzee5e26762020-10-13 16:11:07 -07002016
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002017 if error_name in [
2018 ErrorIf.InputListCondGraphMismatch,
2019 ErrorIf.InputListBodyGraphInputMismatch,
2020 ErrorIf.InputListBodyGraphOutputMismatch,
2021 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002022 incorrect_iter = deepcopy(iter)
2023 for i in range(len(incorrect_iter.shape)):
2024 incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
2025 if len(incorrect_iter.shape) == 0:
2026 incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
2027
2028 incorrect_acc = deepcopy(acc)
2029 for i in range(len(incorrect_acc.shape)):
2030 incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
2031
Eric Kunzee5e26762020-10-13 16:11:07 -07002032 # COND block (input: iter, output: cond_tens )
Jerry Ge9e94af82022-10-27 09:57:00 -07002033 self.ser.addBasicBlock(cond_block)
2034
Matthew Haddon630c17c2021-10-14 15:05:41 +01002035 if error_name == ErrorIf.InputListCondGraphMismatch:
2036 self.ser.addInputTensor(incorrect_iter)
2037 self.ser.addInputTensor(a)
2038 self.ser.addInputTensor(incorrect_acc)
2039 else:
2040 self.ser.addInputTensor(iter)
2041 self.ser.addInputTensor(a)
2042 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002043 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002044
2045 if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002046 cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002047 else:
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002048 cond_type = DType.BOOL
2049 if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
2050 choice = self.rng.choice([1, 2])
2051 if choice == 1:
2052 cond_shape = [3]
2053 else:
2054 cond_shape = [1, 2]
2055 else:
2056 cond_shape = []
2057 cond_tens = self.ser.addOutput(cond_shape, cond_type)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002058
Kevin Cheng550ccc52021-03-03 11:21:43 -08002059 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002060
2061 # BODY block (input: a, acc, iter, output: a, acc, iter)
2062 # Note that local intermediate tensors need to be declared here for the outputs
Jerry Ge9e94af82022-10-27 09:57:00 -07002063 self.ser.addBasicBlock(body_block)
2064
Matthew Haddon630c17c2021-10-14 15:05:41 +01002065 if error_name == ErrorIf.InputListBodyGraphInputMismatch:
2066 self.ser.addInputTensor(incorrect_iter)
2067 self.ser.addInputTensor(a)
2068 self.ser.addInputTensor(incorrect_acc)
2069 else:
2070 self.ser.addInputTensor(iter)
2071 self.ser.addInputTensor(a)
2072 self.ser.addInputTensor(acc)
2073
Kevin Cheng550ccc52021-03-03 11:21:43 -08002074 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002075
2076 if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002077 iter_body_out = self.ser.addIntermediate(
2078 incorrect_iter.shape, incorrect_iter.dtype
2079 )
2080 acc_body_out = self.ser.addIntermediate(
2081 incorrect_acc.shape, incorrect_acc.dtype
2082 )
Matthew Haddon630c17c2021-10-14 15:05:41 +01002083 else:
2084 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
2085 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
2086
Eric Kunzee5e26762020-10-13 16:11:07 -07002087 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
2088 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
2089 self.ser.addOutputTensor(iter_body_out)
2090 self.ser.addOutputTensor(a)
2091 self.ser.addOutputTensor(acc_body_out)
2092
Les Bell729b0352021-11-24 10:28:21 +00002093 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002094 self.ser,
2095 validator_fcns,
2096 error_name,
2097 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002098 basicBlocks=self.ser.currRegion.basicBlocks,
Les Bell729b0352021-11-24 10:28:21 +00002099 ):
2100 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002101
Eric Kunzee5e26762020-10-13 16:11:07 -07002102 return acc_out
2103
def build_fft2d(
    self, op, val1, val2, inverse, validator_fcns=None, error_name=None
):
    """Serialize a two-input FFT operator (``op["op"]``) and its outputs.

    The output tensors are produced by ``OutputShaper.fft2dOp``.  For
    ERROR_IF tests the input/output name lists may first be replaced by
    ``TosaErrorIfArgGen.eiInvalidateInputOutputList``, and
    ``TosaErrorValidator.evValidateErrorIfs`` decides whether the test is
    usable.

    Args:
        op: TOSA_OP_LIST entry; ``op["operands"]`` holds the
            (placeholder, const) operand counts.
        val1: first input tensor.
        val2: second input tensor.
        inverse: value stored in the serialized FFT attribute.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: the ERROR_IF being exercised, or None.

    Returns:
        The list of result tensors, or None when ERROR_IF validation fails.
    """
    results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)

    placeholders, consts = op["operands"]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    # Negative tests may swap in deliberately wrong input/output name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self,
        error_name,
        [val1.name, val2.name],
        [tensor.name for tensor in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(inverse)
    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)
    return results
2145
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
    """Serialize a single-input RFFT operator (``op["op"]``) and its outputs.

    Output tensors come from ``OutputShaper.rfft2dOp``.  For ERROR_IF tests
    the name lists may be invalidated first and
    ``TosaErrorValidator.evValidateErrorIfs`` decides whether the test is
    usable.  No attribute is attached to the operator.

    Args:
        op: TOSA_OP_LIST entry; ``op["operands"]`` holds the
            (placeholder, const) operand counts.
        val: the input tensor.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: the ERROR_IF being exercised, or None.

    Returns:
        The list of result tensors, or None when ERROR_IF validation fails.
    """
    results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)

    placeholders, consts = op["operands"]
    result_shapes = [tensor.shape for tensor in results]
    result_dtypes = [tensor.dtype for tensor in results]

    # Negative tests may swap in deliberately wrong input/output name lists.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        self, error_name, [val.name], [tensor.name for tensor in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)
    return results
2179
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Resolve the shape, rank and dtype filters used to generate op tests.

    Args:
        op: TOSA_OP_LIST entry; ``op["rank"]`` is an inclusive (min, max)
            pair and ``op["types"]`` the dtypes the operator supports.
        shapeFilter: explicit shapes to test; empty/None becomes ``[None]``.
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None.  A list-valued op dtype is
            kept when its first element is requested.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator; its ``param_reqs`` override the
            filters for negative tests.

    Returns:
        A dict with ``shapeFilter``, ``rankFilter`` and ``dtypeFilter`` keys,
        or None for a negative test without a validator or for any other
        ``testType``.
    """
    # Default sweep covers ranks 1-4 inclusive to keep test counts small.
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    low, high = op["rank"]
    if rankFilter is not None:
        # Drop requested ranks the operator does not support.
        ranks = [rank for rank in rankFilter if low <= rank <= high]
    elif shapes[0] is None:
        # Nothing requested: use whichever of the operator's range and the
        # default range contains fewer ranks.
        op_ranks = range(low, high + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(low, high + 1)

    op_types = op["types"]
    if dtypeFilter is None:
        dtypes = op_types
    else:
        dtypes = [
            dtype
            for dtype in op_types
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapes, "rankFilter": ranks, "dtypeFilter": dtypes}

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    # Each requirement, when present, replaces the matching filter.
    req_rank = reqs["rank"]
    neg_ranks = ranks if req_rank is None else req_rank
    req_dtype = reqs["dtype"]
    neg_dtypes = dtypes if req_dtype is None else req_dtype
    req_shape = reqs["shape"]
    # Without required shapes keep only the first two, limiting test numbers.
    neg_shapes = shapes[:2] if req_shape is None else req_shape

    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2260
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Build the list of tests to generate for one operator.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        shapeFilter: optional list of explicit shapes; ``None`` entries let
            the tensor generator choose.  Defaults to ``[None]``.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: ``"positive"`` for valid tests or ``"negative"`` for
            ERROR_IF tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.  Empty when ``create_filter_lists`` yields no filters.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
    """
    # None replaces the former shared mutable [None] default; same meaning.
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    # Initialize a new random number generator
    self.rng = np.random.default_rng(self.random_seed)

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    # Each test is a tuple of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []

        for r in filterDict["rankFilter"]:
            for t in filterDict["dtypeFilter"]:
                for shape in filterDict["shapeFilter"]:
                    # Skip explicit shapes that don't match the current rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument list entries are (str, [build function args])
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t, error_name)
                    else:
                        argList = [("", [])]

                    # create_filter_lists only returns filters for
                    # "positive" and "negative" test types.
                    if testType == "positive":
                        prefix = opName
                    else:
                        prefix = "{}_ERRORIF_{}".format(opName, error_name)
                    baseStr = "{}_{}_{}".format(prefix, shapeStr, typeStr)

                    for argStr, args in argList:
                        if argStr:
                            testStr = "{}_{}".format(baseStr, argStr)
                        else:
                            testStr = baseStr
                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
2367
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
):
    """Create, build and (if valid) serialize a single test.

    Args:
        opName: key into ``self.TOSA_OP_LIST``.
        testStr: test name, passed to ``createSerializer``.
        dtype_or_dtypeList: one dtype used for every operand, or an explicit
            per-operand list.
        error_name: the ERROR_IF being tested, or None.
        shapeList: one shape per operand.
        testArgs: extra positional arguments for the op's build function.

    Raises:
        Exception: if ``opName`` is not in ``self.TOSA_OP_LIST``.
        AssertionError: (when assertions are enabled) if the shape or dtype
            list length does not match the operand count for non-CONCAT ops.
        TypeError: re-raised from the build function after its inputs are
            printed.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if self.args.verbose:
        print(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")
    qgen = op.get("qgen")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    elif op["op"] == Op.CONCAT:
        # CONCAT takes a variable number of inputs: one dtype per shape
        dtypeList = [dtype_or_dtypeList] * len(shapeList)
    else:
        dtypeList = [dtype_or_dtypeList] * num_operands

    if op["op"] != Op.CONCAT:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    if qgen is not None:
        qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
    else:
        qinfo = None

    # Build the random tensor operands
    tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)

    # Builders without ERROR_IF support take qinfo positionally; ERROR_IF
    # aware builders take validators, error_name and qinfo as keywords.
    extra_args = []
    extra_kwargs = {}
    if error_if_validators is None:
        if qinfo is not None:
            extra_args.append(qinfo)
    else:
        extra_kwargs["validator_fcns"] = error_if_validators
        extra_kwargs["error_name"] = error_name
        if qinfo is not None:
            extra_kwargs["qinfo"] = qinfo

    try:
        resultName = build_fcn(
            self, op, *tens, *testArgs, *extra_args, **extra_kwargs
        )
    except TypeError:
        print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
        raise

    if resultName:
        # The test is valid, serialize it
        self.serialize("test")
    else:
        # The test is not valid
        print(f"Invalid ERROR_IF test created: {opName} {testStr}")
Matthew Haddon1c00b712021-10-01 15:51:03 +01002461
def createDynamicOpLists(self):
    """Expand the templated convolution entries of TOSA_OP_LIST.

    For every kernel size, a shallow copy of the matching ``*_TEMPLATE``
    entry is stored under a name such as ``conv2d_3x3``.  The copy has its
    ``filter`` set to the kernel and its ``template`` flag set to False.
    Afterwards every entry whose ``template`` field is truthy is deleted.
    If ``conv2d_TEMPLATE`` is already gone, nothing happens.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Templates already expanded (the class can be initialized repeatedly).
        return

    if self.args.level8k:
        # 8K level: exercise the largest allowed kernel dimension.
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(template_name, variant_name, kernel):
        # Register a concrete copy of a template for one kernel size.
        variant = op_list[template_name].copy()
        op_list[variant_name] = variant
        variant["filter"] = kernel
        variant["template"] = False

    for kernel in kernels_2d:
        add_variant("conv2d_TEMPLATE", "conv2d_{}x{}".format(*kernel), kernel)
        add_variant(
            "depthwise_conv2d_TEMPLATE",
            "depthwise_conv2d_{}x{}".format(*kernel),
            kernel,
        )
        add_variant(
            "transpose_conv2d_TEMPLATE",
            "transpose_conv2d_{}x{}".format(*kernel),
            kernel,
        )

    for kernel in kernels_3d:
        add_variant("conv3d_TEMPLATE", "conv3d_{}x{}x{}".format(*kernel), kernel)

    # Collect the template names first, then delete them: a dict must not be
    # resized while it is being iterated.
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]
    for name in template_names:
        del op_list[name]
2517
def initOpListDefaults(self):
    """Validate every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide an ``operands`` pair, a four-element
    ``build_fcn`` tuple, a ``types`` field and an ``op`` field.  If any is
    missing or malformed, a plain ``Exception`` naming the op is raised.
    Entries without a ``rank`` receive ``self.DEFAULT_RANK_RANGE``.
    """
    for op_name, op_entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder, const) operand counts.
        try:
            _num_placeholders, _num_consts = op_entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        # Required: (build, tensor-gen, tensor-values-gen, arg-gen) tuple.
        try:
            _build, _tgen, _tvgen, _agen = op_entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        if "types" not in op_entry:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op_name)
            )

        if "op" not in op_entry:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op_name)
            )

        # Optional: fall back to the class-wide default rank range.
        op_entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07002559
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build_operator(), TensorGen function,
#    TensorValuesGen function, ArgGen function); the ArgGen entry may be None
# 'types': array of datatypes to be tested
# 'qgen': optional, quantization info generator function
# 'error_if_validators': optional, ERROR_IF validators used for negative tests
# 'invalid_test_validators': optional, validators that remove positive tests
#    which would fail without corresponding to an ERROR_IF statement
# 'template': optional, True for templates expanded by createDynamicOpLists
# 'filter': kernel size, set on ops created by createDynamicOpLists
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]  # floating-point types

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integer types (excluding INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]  # FP32 and INT16

TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
]

# Default inclusive (min, max) rank range for ops that don't specify 'rank'
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07002611
2612 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08002613 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002614 "argmax": {
2615 "op": Op.ARGMAX,
2616 "operands": (1, 0),
Matthew Haddonc4cf0372021-10-11 09:38:10 +01002617 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002618 "build_fcn": (
2619 build_argmax,
2620 TosaTensorGen.tgBasic,
2621 TosaTensorValuesGen.tvgDefault,
2622 TosaArgGen.agAxis,
2623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002624 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002625 "error_if_validators": (
2626 TosaErrorValidator.evAxisSmallerZero,
2627 TosaErrorValidator.evAxisLargerRank,
2628 TosaErrorValidator.evArgmaxOutputRankMismatch,
2629 TosaErrorValidator.evArgmaxOutputShapeMismatch,
2630 TosaErrorValidator.evWrongRank,
2631 TosaErrorValidator.evWrongInputType,
2632 TosaErrorValidator.evWrongOutputType,
2633 TosaErrorValidator.evWrongInputList,
2634 TosaErrorValidator.evWrongOutputList,
2635 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002636 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002637 "avg_pool2d": {
2638 "op": Op.AVG_POOL2D,
2639 "operands": (1, 0),
2640 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002641 "build_fcn": (
2642 build_pool2d,
2643 TosaTensorGen.tgNHWC,
2644 TosaTensorValuesGen.tvgDefault,
2645 TosaArgGen.agPooling,
2646 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002647 "qgen": TosaQuantGen.qgUnary,
2648 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002649 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002650 "error_if_validators": (
2651 TosaErrorValidator.evKernelSmallerOne,
2652 TosaErrorValidator.evStrideSmallerOne,
2653 TosaErrorValidator.evPadSmallerZero,
2654 TosaErrorValidator.evWrongRank,
2655 TosaErrorValidator.evWrongInputType,
2656 TosaErrorValidator.evWrongOutputType,
2657 TosaErrorValidator.evWrongInputList,
2658 TosaErrorValidator.evWrongOutputList,
2659 TosaErrorValidator.evInputZeroPointNotZero,
2660 TosaErrorValidator.evOutputZeroPointNotZero,
2661 TosaErrorValidator.evPadLargerEqualKernel,
2662 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002663 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002664 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002665 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002666 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002667 "conv2d_TEMPLATE": {
2668 "op": Op.CONV2D,
2669 "operands": (1, 2),
2670 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002671 "build_fcn": (
2672 build_conv2d,
2673 TosaTensorGen.tgConv2D,
2674 TosaTensorValuesGen.tvgDefault,
2675 TosaArgGen.agConv,
2676 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002677 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002678 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002679 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2680 "error_if_validators": (
2681 TosaErrorValidator.evWrongInputType,
2682 TosaErrorValidator.evWrongOutputType,
2683 TosaErrorValidator.evWrongInputList,
2684 TosaErrorValidator.evWrongOutputList,
2685 TosaErrorValidator.evInputZeroPointNotZero,
2686 TosaErrorValidator.evWeightZeroPointNotZero,
2687 TosaErrorValidator.evPadSmallerZero,
2688 TosaErrorValidator.evStrideSmallerOne,
2689 TosaErrorValidator.evDilationSmallerOne,
2690 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002691 TosaErrorValidator.evConvOutputShapeMismatch,
2692 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002693 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002694 "template": True,
2695 },
Kevin Cheng1533b852021-09-01 12:51:58 -07002696 # Templated operator. Filled in by createDynamicOpLists
2697 "conv3d_TEMPLATE": {
2698 "op": Op.CONV3D,
2699 "operands": (1, 2),
2700 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002701 "build_fcn": (
2702 build_conv3d,
2703 TosaTensorGen.tgConv3D,
2704 TosaTensorValuesGen.tvgDefault,
2705 TosaArgGen.agConv,
2706 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002707 "qgen": TosaQuantGen.qgConv,
2708 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002709 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2710 "error_if_validators": (
2711 TosaErrorValidator.evWrongInputType,
2712 TosaErrorValidator.evWrongOutputType,
2713 TosaErrorValidator.evWrongInputList,
2714 TosaErrorValidator.evWrongOutputList,
2715 TosaErrorValidator.evInputZeroPointNotZero,
2716 TosaErrorValidator.evWeightZeroPointNotZero,
2717 TosaErrorValidator.evPadSmallerZero,
2718 TosaErrorValidator.evStrideSmallerOne,
2719 TosaErrorValidator.evDilationSmallerOne,
2720 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002721 TosaErrorValidator.evConvOutputShapeMismatch,
2722 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002723 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07002724 "template": True,
2725 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002726 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002727 "depthwise_conv2d_TEMPLATE": {
2728 "op": Op.DEPTHWISE_CONV2D,
2729 "operands": (1, 2),
2730 "filter": [1, 1],
2731 "rank": (4, 4),
2732 "build_fcn": (
2733 build_depthwise_conv2d,
2734 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002735 TosaTensorValuesGen.tvgDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01002736 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002737 ),
2738 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002739 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002740 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
2741 "error_if_validators": (
2742 TosaErrorValidator.evWrongInputType,
2743 TosaErrorValidator.evWrongOutputType,
2744 TosaErrorValidator.evWrongInputList,
2745 TosaErrorValidator.evWrongOutputList,
2746 TosaErrorValidator.evInputZeroPointNotZero,
2747 TosaErrorValidator.evWeightZeroPointNotZero,
2748 TosaErrorValidator.evPadSmallerZero,
2749 TosaErrorValidator.evStrideSmallerOne,
2750 TosaErrorValidator.evDilationSmallerOne,
2751 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002752 TosaErrorValidator.evConvOutputShapeMismatch,
2753 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00002754 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002755 "template": True,
2756 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002757 "fully_connected": {
2758 "op": Op.FULLY_CONNECTED,
2759 "operands": (1, 2),
2760 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002761 "build_fcn": (
2762 build_fully_connected,
2763 TosaTensorGen.tgFullyConnected,
2764 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002765 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002767 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002768 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002769 "error_if_validators": (
2770 TosaErrorValidator.evInputZeroPointNotZero,
2771 TosaErrorValidator.evWeightZeroPointNotZero,
2772 TosaErrorValidator.evWrongRank,
2773 TosaErrorValidator.evWrongInputType,
2774 TosaErrorValidator.evWrongOutputType,
2775 TosaErrorValidator.evWrongInputList,
2776 TosaErrorValidator.evWrongOutputList,
2777 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002778 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002779 "matmul": {
2780 "op": Op.MATMUL,
2781 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002782 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002783 "build_fcn": (
2784 build_matmul,
2785 TosaTensorGen.tgMatmul,
2786 TosaTensorValuesGen.tvgDefault,
James Ward8b390432022-08-12 20:48:56 +01002787 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002788 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002789 "qgen": TosaQuantGen.qgMatmul,
2790 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002791 "error_if_validators": (
2792 TosaErrorValidator.evInputZeroPointNotZero,
2793 TosaErrorValidator.evWrongRank,
2794 TosaErrorValidator.evWrongInputType,
2795 TosaErrorValidator.evWrongOutputType,
2796 TosaErrorValidator.evWrongInputList,
2797 TosaErrorValidator.evWrongOutputList,
2798 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002799 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002800 "max_pool2d": {
2801 "op": Op.MAX_POOL2D,
2802 "operands": (1, 0),
2803 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002804 "build_fcn": (
James Ward8b390432022-08-12 20:48:56 +01002805 build_maxpool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002806 TosaTensorGen.tgNHWC,
2807 TosaTensorValuesGen.tvgDefault,
2808 TosaArgGen.agPooling,
2809 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002810 "types": TYPE_NARROW_INT_FP,
Les Bell0e027d42021-11-09 14:42:14 +00002811 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002812 "error_if_validators": (
2813 TosaErrorValidator.evKernelSmallerOne,
2814 TosaErrorValidator.evStrideSmallerOne,
2815 TosaErrorValidator.evPadSmallerZero,
2816 TosaErrorValidator.evWrongRank,
2817 TosaErrorValidator.evWrongInputType,
2818 TosaErrorValidator.evWrongOutputType,
2819 TosaErrorValidator.evWrongInputList,
2820 TosaErrorValidator.evWrongOutputList,
2821 TosaErrorValidator.evPadLargerEqualKernel,
2822 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002823 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002824 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002825 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002826 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002827 "transpose_conv2d_TEMPLATE": {
2828 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002829 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002830 "rank": (4, 4),
2831 "build_fcn": (
2832 build_transpose_conv2d,
2833 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002834 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002835 TosaArgGen.agTransposeConv2D,
2836 ),
2837 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07002838 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00002839 "invalid_test_validators": (
2840 TosaInvalidValidator.ivHeightWidthInvalid,
2841 TosaInvalidValidator.ivNonPositiveOutputShape,
2842 ),
2843 "error_if_validators": (
2844 TosaErrorValidator.evWrongInputType,
2845 TosaErrorValidator.evWrongOutputType,
2846 TosaErrorValidator.evWrongInputList,
2847 TosaErrorValidator.evWrongOutputList,
2848 TosaErrorValidator.evInputZeroPointNotZero,
2849 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07002850 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00002851 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00002852 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002853 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00002854 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002855 "template": True,
2856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002857 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002858 "clamp": {
2859 "op": Op.CLAMP,
2860 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002861 "build_fcn": (
2862 build_clamp,
2863 TosaTensorGen.tgBasic,
2864 TosaTensorValuesGen.tvgDefault,
2865 None,
2866 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002867 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002868 "error_if_validators": (
2869 TosaErrorValidator.evMaxSmallerMin,
2870 TosaErrorValidator.evWrongInputType,
2871 TosaErrorValidator.evWrongOutputType,
2872 TosaErrorValidator.evWrongInputList,
2873 TosaErrorValidator.evWrongOutputList,
2874 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002875 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08002876 "sigmoid": {
2877 "op": Op.SIGMOID,
2878 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002879 "build_fcn": (
2880 build_sigmoid,
2881 TosaTensorGen.tgBasic,
2882 TosaTensorValuesGen.tvgDefault,
2883 None,
2884 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002885 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002886 "error_if_validators": (
2887 TosaErrorValidator.evWrongInputType,
2888 TosaErrorValidator.evWrongOutputType,
2889 TosaErrorValidator.evWrongInputList,
2890 TosaErrorValidator.evWrongOutputList,
2891 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002892 },
2893 "tanh": {
2894 "op": Op.TANH,
2895 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002896 "build_fcn": (
2897 build_tanh,
2898 TosaTensorGen.tgBasic,
2899 TosaTensorValuesGen.tvgDefault,
2900 None,
2901 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002902 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002903 "error_if_validators": (
2904 TosaErrorValidator.evWrongInputType,
2905 TosaErrorValidator.evWrongOutputType,
2906 TosaErrorValidator.evWrongInputList,
2907 TosaErrorValidator.evWrongOutputList,
2908 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002910 # Elementwise Binary Operators
2911 "add": {
2912 "op": Op.ADD,
2913 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002914 "build_fcn": (
2915 build_binary_broadcast,
2916 TosaTensorGen.tgBroadcastFuzz,
2917 TosaTensorValuesGen.tvgAddSub,
2918 None,
2919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002920 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002921 "error_if_validators": (
2922 TosaErrorValidator.evRankMismatch,
2923 TosaErrorValidator.evWrongInputType,
2924 TosaErrorValidator.evWrongOutputType,
2925 TosaErrorValidator.evWrongInputList,
2926 TosaErrorValidator.evWrongOutputList,
2927 TosaErrorValidator.evDimensionMismatch,
2928 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002929 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002930 "arithmetic_right_shift": {
2931 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2932 "operands": (2, 0),
2933 "build_fcn": (
2934 build_arithmetic_right_shift,
2935 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002936 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08002937 TosaArgGen.agArithmeticRightShift,
2938 ),
2939 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002940 "error_if_validators": (
2941 TosaErrorValidator.evRankMismatch,
2942 TosaErrorValidator.evWrongInputType,
2943 TosaErrorValidator.evWrongOutputType,
2944 TosaErrorValidator.evWrongInputList,
2945 TosaErrorValidator.evWrongOutputList,
2946 TosaErrorValidator.evDimensionMismatch,
2947 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002948 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002949 "bitwise_and": {
2950 "op": Op.BITWISE_AND,
2951 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002952 "build_fcn": (
2953 build_binary_broadcast,
2954 TosaTensorGen.tgBroadcastFuzz,
2955 TosaTensorValuesGen.tvgDefault,
2956 None,
2957 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002958 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002959 "error_if_validators": (
2960 TosaErrorValidator.evRankMismatch,
2961 TosaErrorValidator.evWrongInputType,
2962 TosaErrorValidator.evWrongOutputType,
2963 TosaErrorValidator.evWrongInputList,
2964 TosaErrorValidator.evWrongOutputList,
2965 TosaErrorValidator.evDimensionMismatch,
2966 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002967 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002968 "bitwise_or": {
2969 "op": Op.BITWISE_OR,
2970 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002971 "build_fcn": (
2972 build_binary_broadcast,
2973 TosaTensorGen.tgBroadcastFuzz,
2974 TosaTensorValuesGen.tvgDefault,
2975 None,
2976 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002977 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002978 "error_if_validators": (
2979 TosaErrorValidator.evRankMismatch,
2980 TosaErrorValidator.evWrongInputType,
2981 TosaErrorValidator.evWrongOutputType,
2982 TosaErrorValidator.evWrongInputList,
2983 TosaErrorValidator.evWrongOutputList,
2984 TosaErrorValidator.evDimensionMismatch,
2985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002986 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002987 "bitwise_xor": {
2988 "op": Op.BITWISE_XOR,
2989 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002990 "build_fcn": (
2991 build_binary_broadcast,
2992 TosaTensorGen.tgBroadcastFuzz,
2993 TosaTensorValuesGen.tvgDefault,
2994 None,
2995 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08002996 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002997 "error_if_validators": (
2998 TosaErrorValidator.evRankMismatch,
2999 TosaErrorValidator.evWrongInputType,
3000 TosaErrorValidator.evWrongOutputType,
3001 TosaErrorValidator.evWrongInputList,
3002 TosaErrorValidator.evWrongOutputList,
3003 TosaErrorValidator.evDimensionMismatch,
3004 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003005 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003006 "intdiv": {
3007 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003008 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003009 "build_fcn": (
3010 build_binary_broadcast,
3011 TosaTensorGen.tgBroadcastFuzz,
3012 TosaTensorValuesGen.tvgIntDiv,
3013 None,
3014 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003015 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003016 "error_if_validators": (
3017 TosaErrorValidator.evRankMismatch,
3018 TosaErrorValidator.evWrongInputType,
3019 TosaErrorValidator.evWrongOutputType,
3020 TosaErrorValidator.evWrongInputList,
3021 TosaErrorValidator.evWrongOutputList,
3022 TosaErrorValidator.evDimensionMismatch,
3023 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003024 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003025 "logical_and": {
3026 "op": Op.LOGICAL_AND,
3027 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003028 "build_fcn": (
3029 build_binary_broadcast,
3030 TosaTensorGen.tgBroadcastFuzz,
3031 TosaTensorValuesGen.tvgDefault,
3032 None,
3033 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003034 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003035 "error_if_validators": (
3036 TosaErrorValidator.evRankMismatch,
3037 TosaErrorValidator.evWrongInputType,
3038 TosaErrorValidator.evWrongOutputType,
3039 TosaErrorValidator.evWrongInputList,
3040 TosaErrorValidator.evWrongOutputList,
3041 TosaErrorValidator.evDimensionMismatch,
3042 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003043 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003044 "logical_left_shift": {
3045 "op": Op.LOGICAL_LEFT_SHIFT,
3046 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003047 "build_fcn": (
3048 build_binary_broadcast,
3049 TosaTensorGen.tgBroadcastFuzz,
3050 TosaTensorValuesGen.tvgLogicalShift,
3051 None,
3052 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003053 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003054 "error_if_validators": (
3055 TosaErrorValidator.evRankMismatch,
3056 TosaErrorValidator.evWrongInputType,
3057 TosaErrorValidator.evWrongOutputType,
3058 TosaErrorValidator.evWrongInputList,
3059 TosaErrorValidator.evWrongOutputList,
3060 TosaErrorValidator.evDimensionMismatch,
3061 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003062 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003063 "logical_right_shift": {
3064 "op": Op.LOGICAL_RIGHT_SHIFT,
3065 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003066 "build_fcn": (
3067 build_binary_broadcast,
3068 TosaTensorGen.tgBroadcastFuzz,
3069 TosaTensorValuesGen.tvgLogicalShift,
3070 None,
3071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003072 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003073 "error_if_validators": (
3074 TosaErrorValidator.evRankMismatch,
3075 TosaErrorValidator.evWrongInputType,
3076 TosaErrorValidator.evWrongOutputType,
3077 TosaErrorValidator.evWrongInputList,
3078 TosaErrorValidator.evWrongOutputList,
3079 TosaErrorValidator.evDimensionMismatch,
3080 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003082 "logical_or": {
3083 "op": Op.LOGICAL_OR,
3084 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003085 "build_fcn": (
3086 build_binary_broadcast,
3087 TosaTensorGen.tgBroadcastFuzz,
3088 TosaTensorValuesGen.tvgDefault,
3089 None,
3090 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003091 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003092 "error_if_validators": (
3093 TosaErrorValidator.evRankMismatch,
3094 TosaErrorValidator.evWrongInputType,
3095 TosaErrorValidator.evWrongOutputType,
3096 TosaErrorValidator.evWrongInputList,
3097 TosaErrorValidator.evWrongOutputList,
3098 TosaErrorValidator.evDimensionMismatch,
3099 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003100 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003101 "logical_xor": {
3102 "op": Op.LOGICAL_XOR,
3103 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003104 "build_fcn": (
3105 build_binary_broadcast,
3106 TosaTensorGen.tgBroadcastFuzz,
3107 TosaTensorValuesGen.tvgDefault,
3108 None,
3109 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003110 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003111 "error_if_validators": (
3112 TosaErrorValidator.evRankMismatch,
3113 TosaErrorValidator.evWrongInputType,
3114 TosaErrorValidator.evWrongOutputType,
3115 TosaErrorValidator.evWrongInputList,
3116 TosaErrorValidator.evWrongOutputList,
3117 TosaErrorValidator.evDimensionMismatch,
3118 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003120 "maximum": {
3121 "op": Op.MAXIMUM,
3122 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003123 "build_fcn": (
3124 build_binary_broadcast,
3125 TosaTensorGen.tgBroadcastFuzz,
3126 TosaTensorValuesGen.tvgDefault,
3127 None,
3128 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003129 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003130 "error_if_validators": (
3131 TosaErrorValidator.evRankMismatch,
3132 TosaErrorValidator.evWrongInputType,
3133 TosaErrorValidator.evWrongOutputType,
3134 TosaErrorValidator.evWrongInputList,
3135 TosaErrorValidator.evWrongOutputList,
3136 TosaErrorValidator.evDimensionMismatch,
3137 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003138 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003139 "minimum": {
3140 "op": Op.MINIMUM,
3141 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003142 "build_fcn": (
3143 build_binary_broadcast,
3144 TosaTensorGen.tgBroadcastFuzz,
3145 TosaTensorValuesGen.tvgDefault,
3146 None,
3147 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003148 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003149 "error_if_validators": (
3150 TosaErrorValidator.evRankMismatch,
3151 TosaErrorValidator.evWrongInputType,
3152 TosaErrorValidator.evWrongOutputType,
3153 TosaErrorValidator.evWrongInputList,
3154 TosaErrorValidator.evWrongOutputList,
3155 TosaErrorValidator.evDimensionMismatch,
3156 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003157 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003158 "mul": {
3159 "op": Op.MUL,
3160 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003161 "build_fcn": (
3162 build_mul,
3163 TosaTensorGen.tgBroadcastFuzz,
3164 TosaTensorValuesGen.tvgMul,
3165 TosaArgGen.agMul,
3166 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003167 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003168 "error_if_validators": (
3169 TosaErrorValidator.evWrongInputType,
3170 TosaErrorValidator.evWrongOutputType,
3171 TosaErrorValidator.evWrongInputList,
3172 TosaErrorValidator.evWrongOutputList,
3173 TosaErrorValidator.evRankMismatch,
3174 TosaErrorValidator.evDimensionMismatch,
3175 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003177 "pow": {
3178 "op": Op.POW,
3179 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003180 "build_fcn": (
3181 build_binary_broadcast,
3182 TosaTensorGen.tgBroadcastFuzz,
3183 TosaTensorValuesGen.tvgDefault,
3184 None,
3185 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003186 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003187 "error_if_validators": (
3188 TosaErrorValidator.evRankMismatch,
3189 TosaErrorValidator.evWrongInputType,
3190 TosaErrorValidator.evWrongOutputType,
3191 TosaErrorValidator.evWrongInputList,
3192 TosaErrorValidator.evWrongOutputList,
3193 TosaErrorValidator.evDimensionMismatch,
3194 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003195 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003196 "sub": {
3197 "op": Op.SUB,
3198 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003199 "build_fcn": (
3200 build_binary_broadcast,
3201 TosaTensorGen.tgBroadcastFuzz,
3202 TosaTensorValuesGen.tvgAddSub,
3203 None,
3204 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003205 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003206 "error_if_validators": (
3207 TosaErrorValidator.evRankMismatch,
3208 TosaErrorValidator.evWrongInputType,
3209 TosaErrorValidator.evWrongOutputType,
3210 TosaErrorValidator.evWrongInputList,
3211 TosaErrorValidator.evWrongOutputList,
3212 TosaErrorValidator.evDimensionMismatch,
3213 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003214 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003215 "table": {
3216 "op": Op.TABLE,
3217 # Use the automatic generation functions to create the input array
3218 # but create the table tensor in the build function, as it may be
3219 # a different type from the input
3220 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003221 "build_fcn": (
3222 build_table,
3223 TosaTensorGen.tgBasic,
3224 TosaTensorValuesGen.tvgDefault,
3225 TosaArgGen.agTable,
3226 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01003227 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003228 "error_if_validators": (
3229 TosaErrorValidator.evWrongInputType,
3230 TosaErrorValidator.evWrongOutputType,
3231 TosaErrorValidator.evWrongInputList,
3232 TosaErrorValidator.evWrongOutputList,
3233 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003234 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003235 # Elementwise Unary operators
3236 "abs": {
3237 "op": Op.ABS,
3238 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003239 "build_fcn": (
3240 build_unary,
3241 TosaTensorGen.tgBasic,
3242 TosaTensorValuesGen.tvgDefault,
3243 None,
3244 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003245 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003246 "error_if_validators": (
3247 TosaErrorValidator.evWrongInputType,
3248 TosaErrorValidator.evWrongOutputType,
3249 TosaErrorValidator.evWrongInputList,
3250 TosaErrorValidator.evWrongOutputList,
3251 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003253 "bitwise_not": {
3254 "op": Op.BITWISE_NOT,
3255 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003256 "build_fcn": (
3257 build_unary,
3258 TosaTensorGen.tgBasic,
3259 TosaTensorValuesGen.tvgDefault,
3260 None,
3261 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003262 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003263 "error_if_validators": (
3264 TosaErrorValidator.evWrongInputType,
3265 TosaErrorValidator.evWrongOutputType,
3266 TosaErrorValidator.evWrongInputList,
3267 TosaErrorValidator.evWrongOutputList,
3268 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003269 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003270 "ceil": {
3271 "op": Op.CEIL,
3272 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003273 "build_fcn": (
3274 build_unary,
3275 TosaTensorGen.tgBasic,
3276 TosaTensorValuesGen.tvgDefault,
3277 None,
3278 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003279 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003280 "error_if_validators": (
3281 TosaErrorValidator.evWrongInputType,
3282 TosaErrorValidator.evWrongOutputType,
3283 TosaErrorValidator.evWrongInputList,
3284 TosaErrorValidator.evWrongOutputList,
3285 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003286 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003287 "clz": {
3288 "op": Op.CLZ,
3289 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003290 "build_fcn": (
3291 build_unary,
3292 TosaTensorGen.tgBasic,
3293 TosaTensorValuesGen.tvgDefault,
3294 None,
3295 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003296 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003297 "error_if_validators": (
3298 TosaErrorValidator.evWrongInputType,
3299 TosaErrorValidator.evWrongOutputType,
3300 TosaErrorValidator.evWrongInputList,
3301 TosaErrorValidator.evWrongOutputList,
3302 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003303 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003304 "exp": {
3305 "op": Op.EXP,
3306 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003307 "build_fcn": (
3308 build_unary,
3309 TosaTensorGen.tgBasic,
3310 TosaTensorValuesGen.tvgDefault,
3311 None,
3312 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003313 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003314 "error_if_validators": (
3315 TosaErrorValidator.evWrongInputType,
3316 TosaErrorValidator.evWrongOutputType,
3317 TosaErrorValidator.evWrongInputList,
3318 TosaErrorValidator.evWrongOutputList,
3319 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003320 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003321 "floor": {
3322 "op": Op.FLOOR,
3323 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003324 "build_fcn": (
3325 build_unary,
3326 TosaTensorGen.tgBasic,
3327 TosaTensorValuesGen.tvgDefault,
3328 None,
3329 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003331 "error_if_validators": (
3332 TosaErrorValidator.evWrongInputType,
3333 TosaErrorValidator.evWrongOutputType,
3334 TosaErrorValidator.evWrongInputList,
3335 TosaErrorValidator.evWrongOutputList,
3336 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "log": {
3339 "op": Op.LOG,
3340 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003341 "build_fcn": (
3342 build_unary,
3343 TosaTensorGen.tgBasic,
3344 TosaTensorValuesGen.tvgDefault,
3345 None,
3346 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003348 "error_if_validators": (
3349 TosaErrorValidator.evWrongInputType,
3350 TosaErrorValidator.evWrongOutputType,
3351 TosaErrorValidator.evWrongInputList,
3352 TosaErrorValidator.evWrongOutputList,
3353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003354 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003355 "logical_not": {
3356 "op": Op.LOGICAL_NOT,
3357 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003358 "build_fcn": (
3359 build_unary,
3360 TosaTensorGen.tgBasic,
3361 TosaTensorValuesGen.tvgDefault,
3362 None,
3363 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003365 "error_if_validators": (
3366 TosaErrorValidator.evWrongInputType,
3367 TosaErrorValidator.evWrongOutputType,
3368 TosaErrorValidator.evWrongInputList,
3369 TosaErrorValidator.evWrongOutputList,
3370 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003371 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 "negate": {
3373 "op": Op.NEGATE,
3374 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003375 "build_fcn": (
3376 build_unary,
3377 TosaTensorGen.tgBasic,
3378 TosaTensorValuesGen.tvgNegate,
3379 None,
3380 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003381 "qgen": TosaQuantGen.qgUnary,
3382 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003383 "error_if_validators": (
3384 TosaErrorValidator.evInputZeroPointNotZero,
3385 TosaErrorValidator.evOutputZeroPointNotZero,
3386 TosaErrorValidator.evWrongInputType,
3387 TosaErrorValidator.evWrongOutputType,
3388 TosaErrorValidator.evWrongInputList,
3389 TosaErrorValidator.evWrongOutputList,
3390 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003391 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003392 "reciprocal": {
3393 "op": Op.RECIPROCAL,
3394 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003395 "build_fcn": (
3396 build_unary,
3397 TosaTensorGen.tgBasic,
3398 TosaTensorValuesGen.tvgDefault,
3399 None,
3400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003401 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003402 "error_if_validators": (
3403 TosaErrorValidator.evWrongInputType,
3404 TosaErrorValidator.evWrongOutputType,
3405 TosaErrorValidator.evWrongInputList,
3406 TosaErrorValidator.evWrongOutputList,
3407 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003408 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003409 "rsqrt": {
3410 "op": Op.RSQRT,
3411 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 "build_fcn": (
3413 build_unary,
3414 TosaTensorGen.tgBasic,
3415 TosaTensorValuesGen.tvgDefault,
3416 None,
3417 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003419 "error_if_validators": (
3420 TosaErrorValidator.evWrongInputType,
3421 TosaErrorValidator.evWrongOutputType,
3422 TosaErrorValidator.evWrongInputList,
3423 TosaErrorValidator.evWrongOutputList,
3424 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003426 # Elementwise Ternary operators
3427 "select": {
3428 "op": Op.SELECT,
3429 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003430 "build_fcn": (
3431 build_select,
3432 TosaTensorGen.tgBroadcastFuzz,
3433 TosaTensorValuesGen.tvgSelect,
3434 None,
3435 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003436 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003437 "error_if_validators": (
3438 TosaErrorValidator.evRankMismatch,
3439 TosaErrorValidator.evWrongInputType,
3440 TosaErrorValidator.evWrongOutputType,
3441 TosaErrorValidator.evWrongInputList,
3442 TosaErrorValidator.evWrongOutputList,
3443 TosaErrorValidator.evDimensionMismatch,
3444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003445 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003446 # Comparison operators
3447 "equal": {
3448 "op": Op.EQUAL,
3449 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003450 "build_fcn": (
3451 build_comparison,
3452 TosaTensorGen.tgBroadcastFuzz,
3453 TosaTensorValuesGen.tvgEqual,
3454 None,
3455 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003456 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003457 "error_if_validators": (
3458 TosaErrorValidator.evRankMismatch,
3459 TosaErrorValidator.evWrongInputType,
3460 TosaErrorValidator.evWrongOutputType,
3461 TosaErrorValidator.evWrongInputList,
3462 TosaErrorValidator.evWrongOutputList,
3463 TosaErrorValidator.evDimensionMismatch,
3464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003465 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003466 "greater_equal": {
3467 "op": Op.GREATER_EQUAL,
3468 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003469 "build_fcn": (
3470 build_comparison,
3471 TosaTensorGen.tgBroadcastFuzz,
3472 TosaTensorValuesGen.tvgDefault,
3473 None,
3474 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003475 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003476 "error_if_validators": (
3477 TosaErrorValidator.evRankMismatch,
3478 TosaErrorValidator.evWrongInputType,
3479 TosaErrorValidator.evWrongOutputType,
3480 TosaErrorValidator.evWrongInputList,
3481 TosaErrorValidator.evWrongOutputList,
3482 TosaErrorValidator.evDimensionMismatch,
3483 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 "greater": {
3486 "op": Op.GREATER,
3487 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003488 "build_fcn": (
3489 build_comparison,
3490 TosaTensorGen.tgBroadcastFuzz,
3491 TosaTensorValuesGen.tvgDefault,
3492 None,
3493 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003494 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003495 "error_if_validators": (
3496 TosaErrorValidator.evRankMismatch,
3497 TosaErrorValidator.evWrongInputType,
3498 TosaErrorValidator.evWrongOutputType,
3499 TosaErrorValidator.evWrongInputList,
3500 TosaErrorValidator.evWrongOutputList,
3501 TosaErrorValidator.evDimensionMismatch,
3502 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003504 # Reduction operators
3505 "reduce_all": {
3506 "op": Op.REDUCE_ALL,
3507 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003508 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003509 "build_fcn": (
3510 build_reduce,
3511 TosaTensorGen.tgBasic,
3512 TosaTensorValuesGen.tvgDefault,
3513 TosaArgGen.agAxis,
3514 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003515 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003516 "error_if_validators": (
3517 TosaErrorValidator.evAxisLargerRank,
3518 TosaErrorValidator.evAxisSmallerZero,
3519 TosaErrorValidator.evShapeOfAxisNotOne,
3520 TosaErrorValidator.evWrongInputType,
3521 TosaErrorValidator.evWrongOutputType,
3522 TosaErrorValidator.evWrongRank,
3523 TosaErrorValidator.evWrongInputList,
3524 TosaErrorValidator.evWrongOutputList,
3525 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003527 "reduce_any": {
3528 "op": Op.REDUCE_ANY,
3529 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003530 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003531 "build_fcn": (
3532 build_reduce,
3533 TosaTensorGen.tgBasic,
3534 TosaTensorValuesGen.tvgDefault,
3535 TosaArgGen.agAxis,
3536 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003537 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003538 "error_if_validators": (
3539 TosaErrorValidator.evAxisLargerRank,
3540 TosaErrorValidator.evAxisSmallerZero,
3541 TosaErrorValidator.evShapeOfAxisNotOne,
3542 TosaErrorValidator.evWrongInputType,
3543 TosaErrorValidator.evWrongOutputType,
3544 TosaErrorValidator.evWrongRank,
3545 TosaErrorValidator.evWrongInputList,
3546 TosaErrorValidator.evWrongOutputList,
3547 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "reduce_max": {
3550 "op": Op.REDUCE_MAX,
3551 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003552 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 "build_fcn": (
3554 build_reduce,
3555 TosaTensorGen.tgBasic,
3556 TosaTensorValuesGen.tvgDefault,
3557 TosaArgGen.agAxis,
3558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 "error_if_validators": (
3561 TosaErrorValidator.evAxisLargerRank,
3562 TosaErrorValidator.evAxisSmallerZero,
3563 TosaErrorValidator.evShapeOfAxisNotOne,
3564 TosaErrorValidator.evWrongInputType,
3565 TosaErrorValidator.evWrongOutputType,
3566 TosaErrorValidator.evWrongRank,
3567 TosaErrorValidator.evWrongInputList,
3568 TosaErrorValidator.evWrongOutputList,
3569 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00003572 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003574 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 "build_fcn": (
3576 build_reduce,
3577 TosaTensorGen.tgBasic,
3578 TosaTensorValuesGen.tvgDefault,
3579 TosaArgGen.agAxis,
3580 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003581 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003582 "error_if_validators": (
3583 TosaErrorValidator.evAxisLargerRank,
3584 TosaErrorValidator.evAxisSmallerZero,
3585 TosaErrorValidator.evShapeOfAxisNotOne,
3586 TosaErrorValidator.evWrongInputType,
3587 TosaErrorValidator.evWrongOutputType,
3588 TosaErrorValidator.evWrongRank,
3589 TosaErrorValidator.evWrongInputList,
3590 TosaErrorValidator.evWrongOutputList,
3591 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003592 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003593 "reduce_product": {
3594 "op": Op.REDUCE_PRODUCT,
3595 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003596 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003597 "build_fcn": (
3598 build_reduce,
3599 TosaTensorGen.tgBasic,
3600 TosaTensorValuesGen.tvgDefault,
3601 TosaArgGen.agAxis,
3602 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003603 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003604 "error_if_validators": (
3605 TosaErrorValidator.evAxisLargerRank,
3606 TosaErrorValidator.evAxisSmallerZero,
3607 TosaErrorValidator.evShapeOfAxisNotOne,
3608 TosaErrorValidator.evWrongInputType,
3609 TosaErrorValidator.evWrongOutputType,
3610 TosaErrorValidator.evWrongRank,
3611 TosaErrorValidator.evWrongInputList,
3612 TosaErrorValidator.evWrongOutputList,
3613 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003614 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003615 "reduce_sum": {
3616 "op": Op.REDUCE_SUM,
3617 "operands": (1, 0),
Jeremy Johnson3ca02a72021-11-18 12:18:39 +00003618 "rank": (1, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 "build_fcn": (
3620 build_reduce,
3621 TosaTensorGen.tgBasic,
3622 TosaTensorValuesGen.tvgReduceSum,
3623 TosaArgGen.agAxis,
3624 ),
James Ward24dbc422022-10-19 12:20:31 +01003625 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003626 "error_if_validators": (
3627 TosaErrorValidator.evAxisLargerRank,
3628 TosaErrorValidator.evAxisSmallerZero,
3629 TosaErrorValidator.evShapeOfAxisNotOne,
3630 TosaErrorValidator.evWrongInputType,
3631 TosaErrorValidator.evWrongOutputType,
3632 TosaErrorValidator.evWrongRank,
3633 TosaErrorValidator.evWrongInputList,
3634 TosaErrorValidator.evWrongOutputList,
3635 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003636 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003637 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 "concat": {
3639 "op": Op.CONCAT,
3640 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 "build_fcn": (
3642 build_concat,
3643 TosaTensorGen.tgConcat,
3644 TosaTensorValuesGen.tvgConcat,
3645 TosaArgGen.agAxis,
3646 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003647 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003648 "error_if_validators": (
3649 TosaErrorValidator.evAxisLargerRank,
3650 TosaErrorValidator.evAxisSmallerZero,
3651 TosaErrorValidator.evConcatInputRankMismatch,
3652 TosaErrorValidator.evConcatShapeSumMismatch,
3653 TosaErrorValidator.evConcatInputDimMismatch,
3654 TosaErrorValidator.evWrongInputType,
3655 TosaErrorValidator.evWrongOutputType,
3656 TosaErrorValidator.evWrongOutputList,
3657 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003658 },
3659 "pad": {
3660 "op": Op.PAD,
3661 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003662 "build_fcn": (
3663 build_pad,
3664 TosaTensorGen.tgBasic,
3665 TosaTensorValuesGen.tvgDefault,
3666 TosaArgGen.agPad,
3667 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003668 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003669 "error_if_validators": (
3670 TosaErrorValidator.evWrongInputType,
3671 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01003672 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003673 TosaErrorValidator.evWrongOutputType,
3674 TosaErrorValidator.evWrongInputList,
3675 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003676 TosaErrorValidator.evRankMismatch,
3677 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003678 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003679 },
3680 "reshape": {
3681 "op": Op.RESHAPE,
3682 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003683 "build_fcn": (
3684 build_reshape,
3685 TosaTensorGen.tgBasic,
3686 TosaTensorValuesGen.tvgDefault,
3687 TosaArgGen.agReshape,
3688 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003689 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003690 "error_if_validators": (
3691 TosaErrorValidator.evTensorSizeInputOutputMismatch,
3692 TosaErrorValidator.evWrongInputType,
3693 TosaErrorValidator.evWrongOutputType,
3694 TosaErrorValidator.evWrongInputList,
3695 TosaErrorValidator.evWrongOutputList,
3696 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003697 },
3698 "reverse": {
3699 "op": Op.REVERSE,
3700 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003701 "build_fcn": (
3702 build_reverse,
3703 TosaTensorGen.tgBasic,
3704 TosaTensorValuesGen.tvgDefault,
3705 TosaArgGen.agAxis,
3706 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003707 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003708 "error_if_validators": (
3709 TosaErrorValidator.evAxisSmallerZero,
3710 TosaErrorValidator.evAxisLargerRank,
3711 TosaErrorValidator.evWrongInputType,
3712 TosaErrorValidator.evWrongOutputType,
3713 TosaErrorValidator.evWrongInputList,
3714 TosaErrorValidator.evWrongOutputList,
3715 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003716 },
3717 "slice": {
3718 "op": Op.SLICE,
3719 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003720 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 "build_fcn": (
3722 build_slice,
3723 TosaTensorGen.tgBasic,
3724 TosaTensorValuesGen.tvgDefault,
3725 TosaArgGen.agSlice,
3726 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003727 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003728 "error_if_validators": (
3729 TosaErrorValidator.evStartSmallerZero,
3730 TosaErrorValidator.evSizeSmallerEqualZero,
3731 TosaErrorValidator.evStartSizeOutsideBounds,
3732 TosaErrorValidator.evSizeOutputShapeMismatch,
3733 TosaErrorValidator.evInputSizeStartLengthMismatch,
3734 TosaErrorValidator.evWrongRank,
3735 TosaErrorValidator.evWrongInputType,
3736 TosaErrorValidator.evWrongOutputType,
3737 TosaErrorValidator.evWrongInputList,
3738 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003739 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003740 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003741 },
3742 "tile": {
3743 "op": Op.TILE,
3744 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003745 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003746 "build_fcn": (
3747 build_tile,
3748 TosaTensorGen.tgBasic,
3749 TosaTensorValuesGen.tvgDefault,
3750 TosaArgGen.agTile,
3751 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003752 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003753 "error_if_validators": (
3754 TosaErrorValidator.evWrongInputType,
3755 TosaErrorValidator.evWrongOutputType,
3756 TosaErrorValidator.evWrongInputList,
3757 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003758 TosaErrorValidator.evRankMismatch,
3759 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003760 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003761 },
3762 "transpose": {
3763 "op": Op.TRANSPOSE,
3764 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00003765 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003766 "build_fcn": (
3767 build_transpose,
3768 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 TosaTensorValuesGen.tvgDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003770 TosaArgGen.agTranspose,
3771 ),
3772 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003773 "error_if_validators": (
3774 TosaErrorValidator.evIndexOutsideBounds,
3775 TosaErrorValidator.evIndexUsedTwice,
3776 TosaErrorValidator.evWrongInputType,
3777 TosaErrorValidator.evWrongOutputType,
3778 TosaErrorValidator.evWrongInputList,
3779 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00003780 TosaErrorValidator.evWrongRank,
3781 TosaErrorValidator.evRankMismatch,
3782 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003783 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003784 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 # Data nodes
3786 "const": {
3787 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07003788 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 "build_fcn": (
3790 build_const,
3791 TosaTensorGen.tgBasic,
3792 TosaTensorValuesGen.tvgDefault,
3793 None,
3794 ),
Luke Hutton65872422023-02-20 10:33:04 +00003795 "types": TYPE_FIB + [DType.INT48],
Jared Smolens573ecd42021-03-04 15:24:10 -08003796 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003797 "identity": {
3798 "op": Op.IDENTITY,
3799 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003800 "build_fcn": (
3801 build_unary,
3802 TosaTensorGen.tgBasic,
3803 TosaTensorValuesGen.tvgDefault,
3804 None,
3805 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003806 "types": TYPE_FIB,
3807 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003808 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08003809 "gather": {
3810 "op": Op.GATHER,
3811 # Only specify 'values' tensor here. 'indices' is generated in op building stage
3812 "operands": (1, 0),
3813 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 "build_fcn": (
3815 build_gather,
3816 TosaTensorGen.tgBasic,
3817 TosaTensorValuesGen.tvgDefault,
3818 None,
3819 ),
James Ward24dbc422022-10-19 12:20:31 +01003820 "types": (
3821 DType.INT8,
3822 DType.INT16,
3823 DType.INT32,
3824 DType.FP16,
3825 DType.BF16,
3826 DType.FP32,
3827 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 "error_if_validators": (
3829 TosaErrorValidator.evWrongInputType,
3830 TosaErrorValidator.evWrongOutputType,
3831 TosaErrorValidator.evWrongInputList,
3832 TosaErrorValidator.evWrongOutputList,
3833 TosaErrorValidator.evWrongRank,
3834 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003835 },
3836 "scatter": {
3837 "op": Op.SCATTER,
3838 # Only specify 'values_in' tensor here.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 # 'indices' and 'input' are generated in op building stage
Kevin Cheng550ccc52021-03-03 11:21:43 -08003840 "operands": (2, 0),
3841 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003842 "build_fcn": (
3843 build_scatter,
3844 TosaTensorGen.tgScatter,
3845 TosaTensorValuesGen.tvgDefault,
3846 None,
3847 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003848 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003849 "error_if_validators": (
3850 TosaErrorValidator.evWrongInputType,
3851 TosaErrorValidator.evWrongOutputType,
3852 TosaErrorValidator.evWrongInputList,
3853 TosaErrorValidator.evWrongOutputList,
3854 TosaErrorValidator.evWrongRank,
3855 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003856 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003857 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08003858 "resize": {
3859 "op": Op.RESIZE,
3860 "operands": (1, 0),
3861 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003862 "build_fcn": (
3863 build_resize,
3864 TosaTensorGen.tgNHWC,
3865 TosaTensorValuesGen.tvgDefault,
3866 TosaArgGen.agResize,
3867 ),
James Ward24dbc422022-10-19 12:20:31 +01003868 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003869 "invalid_test_validators": (
3870 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 ),
3872 "error_if_validators": (
3873 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003874 TosaErrorValidator.evScaleSmallerEqualZero,
3875 TosaErrorValidator.evScaleNLargerMax,
3876 TosaErrorValidator.evScaleDLargerMax,
3877 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003878 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003879 TosaErrorValidator.evBorderSmallerMin,
3880 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003881 TosaErrorValidator.evWrongInputType,
3882 TosaErrorValidator.evWrongOutputType,
3883 TosaErrorValidator.evWrongRank,
3884 TosaErrorValidator.evWrongInputList,
3885 TosaErrorValidator.evWrongOutputList,
3886 TosaErrorValidator.evBatchMismatch,
3887 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003888 TosaErrorValidator.evResizeOutputShapeMismatch,
3889 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003890 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003891 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003892 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08003893 "cast": {
3894 "op": Op.CAST,
3895 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003896 "build_fcn": (
3897 build_cast,
3898 TosaTensorGen.tgBasic,
3899 TosaTensorValuesGen.tvgDefault,
3900 TosaArgGen.agCast,
3901 ),
James Ward8b390432022-08-12 20:48:56 +01003902 "types": (
3903 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003904 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003905 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003906 DType.INT8,
3907 DType.INT16,
3908 DType.INT32,
3909 DType.BOOL,
3910 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003911 "error_if_validators": (
3912 TosaErrorValidator.evWrongInputType,
3913 TosaErrorValidator.evWrongOutputType,
3914 TosaErrorValidator.evWrongInputList,
3915 TosaErrorValidator.evWrongOutputList,
3916 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003917 },
3918 "rescale": {
3919 "op": Op.RESCALE,
3920 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003921 "build_fcn": (
3922 build_rescale,
3923 TosaTensorGen.tgBasic,
3924 TosaTensorValuesGen.tvgDefault,
3925 TosaArgGen.agRescale,
3926 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003927 "types": [
3928 DType.UINT8,
3929 DType.INT8,
3930 DType.INT16,
3931 DType.INT32,
3932 DType.INT48,
3933 DType.UINT16,
3934 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003935 "error_if_validators": (
3936 TosaErrorValidator.evInputZeroPointNotZero,
3937 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01003938 TosaErrorValidator.evU16InputZeroPointNotValid,
3939 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003940 TosaErrorValidator.evScaleTrue,
3941 TosaErrorValidator.evScaleNotTrue,
3942 TosaErrorValidator.evWrongInputType,
3943 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003944 TosaErrorValidator.evWrongInputList,
3945 TosaErrorValidator.evWrongOutputList,
3946 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003947 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003948 # Custom
3949 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08003950 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors
        # (no inputs to the basic blocks, one output), and another that either
        # adds or subtracts two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08003954 "cond_if_const": {
3955 "op": Op.COND_IF,
3956 "operands": (0, 2),
3957 "build_fcn": (
3958 build_cond_if_const,
3959 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003960 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003961 TosaArgGen.agCondIf,
3962 ),
3963 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 "error_if_validators": (
3965 TosaErrorValidator.evOutputListThenGraphMismatch,
3966 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003967 TosaErrorValidator.evCondIfCondNotMatchingBool,
3968 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003969 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003970 },
3971 "cond_if_binary": {
3972 "op": Op.COND_IF,
3973 "operands": (2, 0),
3974 "build_fcn": (
3975 build_cond_if_binary,
3976 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003977 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003978 TosaArgGen.agCondIf,
3979 ),
Les Bell6040b4d2021-10-11 12:50:31 +01003980 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003981 "error_if_validators": (
3982 TosaErrorValidator.evInputListThenGraphMismatch,
3983 TosaErrorValidator.evInputListElseGraphMismatch,
3984 TosaErrorValidator.evOutputListThenGraphMismatch,
3985 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00003986 TosaErrorValidator.evCondIfCondNotMatchingBool,
3987 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003988 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003989 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003990 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08003991 "while_loop": {
3992 "op": Op.WHILE_LOOP,
3993 "operands": (0, 1),
3994 "build_fcn": (
3995 build_while_loop,
3996 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003997 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003998 TosaArgGen.agWhileLoop,
3999 ),
4000 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evInputListOutputListMismatch,
4003 TosaErrorValidator.evInputListCondGraphMismatch,
4004 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4005 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4006 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004007 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004008 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004009 },
Luke Hutton57287132023-02-06 14:54:18 +00004010 "fft2d": {
4011 "op": Op.FFT2D,
4012 "operands": (2, 0),
4013 "rank": (3, 3),
4014 "build_fcn": (
4015 build_fft2d,
4016 TosaTensorGen.tgFFT2d,
4017 TosaTensorValuesGen.tvgDefault,
4018 TosaArgGen.agFFT2d,
4019 ),
4020 "types": [DType.FP32],
4021 "error_if_validators": (
4022 TosaErrorValidator.evWrongInputType,
4023 TosaErrorValidator.evWrongOutputType,
4024 TosaErrorValidator.evWrongInputList,
4025 TosaErrorValidator.evWrongOutputList,
4026 TosaErrorValidator.evWrongRank,
4027 TosaErrorValidator.evBatchMismatch,
4028 TosaErrorValidator.evKernelNotPowerOfTwo,
4029 TosaErrorValidator.evFFTInputShapeMismatch,
4030 TosaErrorValidator.evFFTOutputShapeMismatch,
4031 ),
4032 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004033 "rfft2d": {
4034 "op": Op.RFFT2D,
4035 "operands": (1, 0),
4036 "rank": (3, 3),
4037 "build_fcn": (
4038 build_rfft2d,
4039 TosaTensorGen.tgRFFT2d,
4040 TosaTensorValuesGen.tvgDefault,
4041 TosaArgGen.agNone,
4042 ),
4043 "types": [DType.FP32],
4044 "error_if_validators": (
4045 TosaErrorValidator.evWrongInputType,
4046 TosaErrorValidator.evWrongOutputType,
4047 TosaErrorValidator.evWrongInputList,
4048 TosaErrorValidator.evWrongOutputList,
4049 TosaErrorValidator.evWrongRank,
4050 TosaErrorValidator.evBatchMismatch,
4051 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004052 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004053 ),
4054 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004055 }
4056
Kevin Cheng550ccc52021-03-03 11:21:43 -08004057
Eric Kunzee5e26762020-10-13 16:11:07 -07004058class OutputShaper:
4059 # Methods in this class compute the expected output shape and datatype
4060 # for common classes of operations
4061 def __init__(self):
4062 pass
4063
    # These static methods create the operator's output tensor via
    # ser.addOutput() and return whatever that call returns.
4066 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004067 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
4068 if error_name != ErrorIf.RankMismatch:
4069 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004070 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004071
4072 shape = []
4073 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004074 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07004075 shape.append(b.shape[i])
4076 else:
4077 shape.append(a.shape[i])
4078
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004079 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004080 all_dtypes = [
4081 DType.INT8,
4082 DType.INT16,
4083 DType.INT32,
4084 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01004085 DType.FP16,
4086 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004087 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004088 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01004089 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4090 outputDType = rng.choice(wrong_dtypes)
4091 else:
4092 outputDType = a.dtype
4093
4094 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004095
4096 @staticmethod
4097 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004098 assert len(a.shape) == len(b.shape)
4099 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004100
4101 shape = []
4102 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08004103 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07004104 shape.append(a.shape[i])
4105
Kevin Cheng550ccc52021-03-03 11:21:43 -08004106 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004107
4108 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004109 def unaryOp(ser, rng, a, error_name=None):
4110 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004111 all_dtypes = [
4112 DType.INT8,
4113 DType.INT16,
4114 DType.INT32,
4115 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004116 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004117 DType.FP16,
4118 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004119 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01004120 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4121 outputDType = rng.choice(wrong_dtypes)
4122 else:
4123 outputDType = a.dtype
4124
4125 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004126
4127 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004128 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004129 if error_name != ErrorIf.RankMismatch:
4130 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004131 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004132
4133 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004134 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004135 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004136 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
4137 else:
4138 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07004139
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004140 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004141 all_dtypes = [
4142 DType.INT8,
4143 DType.INT16,
4144 DType.INT32,
4145 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004146 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004147 DType.FP16,
4148 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004149 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004150 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4151 outputDType = rng.choice(wrong_dtypes)
4152 else:
4153 outputDType = a.dtype
4154
4155 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004156
4157 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004158 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00004159 if error_name != ErrorIf.RankMismatch:
4160 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004161 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004162
4163 # Do broadcast
4164 shape = []
4165 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08004166 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07004167 shape.append(b.shape[i])
4168 else:
4169 shape.append(a.shape[i])
4170
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004171 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004172 wrong_dtypes = [
4173 DType.INT8,
4174 DType.INT16,
4175 DType.INT32,
4176 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004177 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004178 DType.FP16,
4179 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004180 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01004181 outputDType = rng.choice(wrong_dtypes)
4182 else:
4183 outputDType = DType.BOOL
4184
4185 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004186
4187 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01004188 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004189 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004190 if error_name not in [
4191 ErrorIf.AxisSmallerZero,
4192 ErrorIf.AxisLargerRank,
4193 ErrorIf.ShapeOfAxisNotOne,
4194 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01004195 shape[axis] = 1
4196 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
4197 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07004198
Matthew Haddond6ce7252021-09-29 15:35:44 +01004199 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004200 all_dtypes = [
4201 DType.INT8,
4202 DType.INT16,
4203 DType.INT32,
4204 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004205 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004206 DType.FP16,
4207 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004208 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01004209 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
4210 outputDType = rng.choice(wrong_dtypes)
4211 else:
4212 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07004213
Matthew Haddond6ce7252021-09-29 15:35:44 +01004214 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004215
4216 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004217 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07004218 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004219
4220 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
4221 del shape[axis]
4222
4223 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
4224 remove = rng.choice([True, False])
4225 if remove and len(shape) > 1:
4226 del shape[0]
4227 else:
4228 shape.append(1)
4229 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
4230 for i in range(len(shape)):
4231 shape[i] = shape[i] + rng.integers(1, 10)
4232
4233 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004234 all_dtypes = [
4235 DType.INT8,
4236 DType.INT16,
4237 DType.INT32,
4238 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004239 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01004240 DType.FP16,
4241 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004242 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01004243 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
4244 outputDType = rng.choice(wrong_dtypes)
4245 else:
4246 outputDType = DType.INT32
4247
4248 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004249
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add and return the output tensor for a CONV2D operator.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding[0:2]``
    applies to H and ``padding[2:4]`` to W; ``strides`` and ``dilations``
    are indexed [H, W].
    """
    ifm_h, ifm_w = ifm.shape[1], ifm.shape[2]
    kernel_h, kernel_w = filter.shape[1], filter.shape[2]

    # Spatial output size, with the kernel extent scaled by its dilation.
    out_h = (
        ifm_h - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    out_w = (
        ifm_w - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer output error is not hit.
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    # When the input type is deliberately wrong, pick a potentially valid
    # output type; otherwise the output uses the accumulator type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32, presumably because both
        # are acceptable accumulator types for FP16.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004301
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add and return the output tensor for a CONV3D operator.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. ``padding`` holds
    one (before, after) pair per spatial dim in D, H, W order;
    ``strides`` and ``dilations`` hold one value per spatial dim in the
    same order.
    """

    def _out_size(i):
        # i indexes the spatial dims: 0 = D, 1 = H, 2 = W.
        return (
            ifm.shape[i + 1]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i + 1] - 1) * dilations[i]
        ) // strides[i] + 1

    out_d = _out_size(0)
    out_h = _out_size(1)
    out_w = _out_size(2)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer output error is not hit.
        if which in (1, 4):
            out_d += rng.choice(steps) * strides[0]
        if which in (2, 4):
            out_h += rng.choice(steps) * strides[1]
        if which in (3, 4):
            out_w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    # When the input type is deliberately wrong, pick a potentially valid
    # output type; otherwise the output uses the accumulator type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32, presumably because both
        # are acceptable accumulator types for FP16.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
4363
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add and return the output tensor for a DEPTHWISE_CONV2D operator.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). ``padding[0:2]``
    applies to H and ``padding[2:4]`` to W; ``strides`` and ``dilations``
    are indexed [H, W].
    """
    kernel_h, kernel_w, in_channels, multiplier = (
        filter.shape[0],
        filter.shape[1],
        filter.shape[2],
        filter.shape[3],
    )

    out_h = (
        ifm.shape[1] - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    out_w = (
        ifm.shape[2] - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer output error is not hit.
        if which in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, in_channels * multiplier]

    # When the input type is deliberately wrong, pick a potentially valid
    # output type; otherwise the output uses the accumulator type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32, presumably because both
        # are acceptable accumulator types for FP16.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004414
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add and return the output tensor for a 2D pooling operator.

    The input is NHWC; ``pad[0:2]`` applies to H and ``pad[2:4]`` to W, and
    ``kernel``/``stride`` are indexed [H, W]. The output keeps the input
    batch and channel sizes and, unless WrongOutputType is requested, the
    input dtype.
    """
    has_bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if has_bad_params:
        # The test is invalid anyway, so use unit spatial dims.
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer output error is not hit.
        if which in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([ifm.dtype]))
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004452
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add and return the output tensor for a FULLY_CONNECTED operator.

    Shapes: input is [N, IC], filter is [OC, IC], output is [N, OC].
    The output dtype is ``accum_dtype`` as given; the arg generator is
    responsible for validating it, or for invalidating it in ERROR_IF
    tests. ``rng`` and ``error_name`` are unused here.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004465
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add and return the output tensor for a MATMUL operator.

    Shapes: a is [N, H, C], b is [N, C, W], output is [N, H, W].

    Output dtype:
      * WrongOutputType: a random type from a per-input-dtype list that
        omits the type(s) considered correct for ``a.dtype``.
      * WrongInputType: INT32, as a potentially valid output type.
      * otherwise: ``accum_dtype``, which the arg generator validates.

    Raises:
        ValueError: for WrongOutputType when ``a.dtype`` has no list of
            incorrect output types defined here. Previously this case
            surfaced as an UnboundLocalError.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        if a.dtype == DType.INT8:
            # INT32 is omitted: the expected output type for INT8 inputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype == DType.INT16:
            # INT48 is omitted: the expected output type for INT16 inputs.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Floating-point inputs: every integer output type is wrong.
            incorrect_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
            )
        else:
            raise ValueError(
                f"matmulOp: no incorrect output types defined for input dtype {a.dtype}"
            )
        out_dtype = rng.choice(a=incorrect_types)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype  # Validated in arg_gen

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004513
@staticmethod
def concatOp(ser, rng, axis, *a, error_name=None):
    """Add and return the output tensor for a CONCAT operator.

    The output starts from the first input's shape and extends ``axis`` by
    each remaining input's size along that axis. When the ranks differ or
    the axis is invalid (per ``error_name``), the sum is skipped and the
    first input's shape is used unchanged.
    """
    first = a[0]
    output_shape = first.shape.copy()

    sum_is_impossible = error_name in [
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if not sum_is_impossible:
        for other in a[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, first.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    wrong_dtypes = list(all_dtypes - set([first.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004549
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add and return the output tensor for a PAD operator.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ERROR_IF cases perturb the shape (PadOutputShapeMismatch, RankMismatch)
    or choose a dtype other than ``a.dtype`` (WrongOutputType).
    """
    output_shape = [
        padding[i][0] + padding[i][1] + dim for i, dim in enumerate(a.shape)
    ]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        delta = rng.choice([1, 2])
        if output_shape[bad_dim] - delta >= 1:
            output_shape[bad_dim] -= delta
        else:
            # Shrinking would leave a zero or negative dimension, which is
            # unlikely to be a usable tensor shape. Growing still produces
            # the required mismatch and consumes the same rng draws.
            output_shape[bad_dim] += delta
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
        outputDType = rng.choice(wrong_dtypes)
    else:
        outputDType = a.dtype

    return ser.addOutput(output_shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07004580
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add and return the output tensor for a RESHAPE operator.

    The output takes the requested ``shape`` (copied). For
    TensorSizeInputOutputMismatch every dim is grown by a random amount in
    [1, 9]; for WrongOutputType the dtype differs from ``a.dtype``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004605
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add and return the output tensor for a SLICE operator.

    The output shape is a copy of ``size`` and the dtype is
    ``input.dtype``; both are perturbed for the matching ERROR_IF cases.
    ``start`` is not used here.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([input.dtype])))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dims of 2 or less are only increased; larger ones move either way.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004639
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add and return the output tensor for a TILE operator.

    Each output dim is ``a.shape[i] * multiples[i]``; ``multiples`` must
    have one entry per input dim.
    """
    assert len(multiples) == len(a.shape)
    output_shape = [dim * multiples[i] for i, dim in enumerate(a.shape)]

    if error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004668
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add and return the output tensor for a TRANSPOSE operator.

    Output dim i is ``a.shape[perms[i]]``. When ``perms`` is deliberately
    invalid (IndexOutsideBounds / IndexUsedTwice), the input shape is used
    unchanged.
    """
    assert len(perms) == len(a.shape)

    perms_usable = error_name not in [
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    ]
    if perms_usable:
        output_shape = [a.shape[p] for p in perms]
    else:
        output_shape = a.shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        output_shape = [dim + rng.integers(1, 10) for dim in output_shape]
    elif error_name == ErrorIf.RankMismatch:
        output_shape = get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004701
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add and return the output tensor for a GATHER operator.

    Unless WrongRank is requested, ``values`` must be rank 3 and
    ``indices`` rank 2 with the same leading dim. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([values.dtype]))
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08004727
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add and return the output tensor for a SCATTER operator.

    Unless WrongRank is requested, ``values_in`` and ``input`` must be
    rank 3 and ``indices`` rank 2, with matching N, W and C dims. The
    output reuses ``values_in.shape`` (the same object, not a copy).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([values_in.dtype]))
    return ser.addOutput(values_in.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07004756
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add and return the output tensor for a TABLE operator.

    The output has the input's shape. Its dtype is INT32 for INT16 inputs
    and INT8 otherwise; for WrongOutputType a different type is chosen.
    Unless WrongInputType is requested, the input must be INT8 or INT16.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        output_dtype = rng.choice([d for d in candidate_dtypes if d != output_dtype])

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004776
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add and return the output tensor for a RESIZE operator.

    ``scale`` holds [y_numerator, y_denominator, x_numerator, x_denominator];
    ``offset`` and ``border`` are indexed [y, x]. Each spatial output size
    is ``((in - 1) * numerator - offset + border) // denominator + 1``, and
    the output is [N, OH, OW, C] with dtype ``output_dtype``. ``mode`` and
    ``input_dtype`` are not used here.
    """
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the shape arithmetic below stays meaningful.
        scale_y_n, scale_y_d, scale_x_n, scale_x_d = (
            max(v, 1) for v in (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        )

    in_h, in_w = input.shape[1], input.shape[2]
    oh = ((in_h - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((in_w - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Parameters altered for ERROR_IF tests can yield an unusable size;
        # keep it at least 1 and, unless testing the limit, below the maximum.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            oh = min(oh, MAX_RESIZE_DIMENSION - 1)
            ow = min(ow, MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _shift(size, step):
            # Move by one whole denominator so the non-integer error case is
            # not hit, going down instead if going up would reach the maximum.
            if size + step < MAX_RESIZE_DIMENSION:
                return size + step
            size -= step
            assert size > 0  # Should have been caught in agResize
            return size

        which = rng.choice([1, 2, 3])
        if which in (1, 3):
            oh = _shift(oh, scale_y_d)
        if which in (2, 3):
            ow = _shift(ow, scale_x_d)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): input.shape[0] fills the channel slot, presumably
        # because input.shape[3] may not exist when the rank is wrong.
        output_dims = [batch, oh, ow, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004855
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add and return an output tensor with ``val``'s shape and ``out_dtype``.

    ``rng`` and ``error_name`` are accepted but not used here.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07004859
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add and return the output tensor for a TRANSPOSE_CONV2D operator.

    ``output_shape`` is supplied by the caller rather than computed here.
    For ConvOutputShapeMismatch, its entries at indices 1 and/or 2 are
    increased IN PLACE, so the caller's object sees the change.
    NOTE(review): confirm whether callers rely on this mutation (e.g. by
    reusing the same list elsewhere) before switching to a copy.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        for dim_index, triggers in ((1, (1, 3)), (2, (2, 3))):
            if which in triggers:
                output_shape[dim_index] += rng.choice(steps)

    # When the input type is deliberately wrong, pick a potentially valid
    # output type; otherwise the output uses the accumulator type.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        # FP16 inputs exclude both FP16 and FP32, presumably because both
        # are acceptable accumulator types for FP16.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00004885
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add and return the two output tensors for an FFT2D operator.

    Both inputs must share a dtype. Unless FFTInputShapeMismatch is
    requested they must share a shape, and unless WrongRank is requested
    that shape must be rank 3. Both outputs take the first input's shape
    (copied) and dtype, perturbed for the matching ERROR_IF cases; the
    two outputs share the same shape list.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    output_shape = ifm1.shape.copy()
    output_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(list(usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        output_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        bad_dim = rng.choice([1, 2])
        output_shape[bad_dim] += rng.integers(1, 10)

    return [serializer.addOutput(output_shape, output_dtype) for _ in range(2)]
4915 return outputs
4916
4917 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00004918 def rfft2dOp(serializer, rng, value, error_name=None):
4919 outputs = []
4920
4921 input_shape = value.shape
4922 if error_name != ErrorIf.WrongRank:
4923 assert len(input_shape) == 3
4924
4925 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
4926
4927 output_dtype = value.dtype
4928 if error_name == ErrorIf.WrongOutputType:
4929 excludes = [DType.FP32]
4930 wrong_dtypes = list(usableDTypes(excludes=excludes))
4931 output_dtype = rng.choice(wrong_dtypes)
4932 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00004933 output_shape[0] += rng.integers(1, 10)
4934 elif error_name == ErrorIf.FFTOutputShapeMismatch:
4935 modify_dim = rng.choice([1, 2])
4936 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00004937
4938 outputs.append(serializer.addOutput(output_shape, output_dtype))
4939 outputs.append(serializer.addOutput(output_shape, output_dtype))
4940 return outputs