blob: 88dd17a241f472584330b99fc005b92079dd02dc [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01006from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01007from datetime import datetime
8from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07009
Jeremy Johnson1271c442023-09-05 11:39:26 +010010import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010013from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010014from generator.tosa_arg_gen import TosaArgGen
15from generator.tosa_arg_gen import TosaQuantGen
16from generator.tosa_arg_gen import TosaTensorGen
17from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000018from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_error_if import TosaErrorIfArgGen
20from generator.tosa_error_if import TosaErrorValidator
21from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010022from generator.tosa_random_gen import TosaHashRandomGenerator
23from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010024from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000025from tosa.DType import DType
26from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010027
# Banner written at the top of every generated C++ meta-data file.
# The copyright year is captured once, when this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# Module-wide logger for the test builder
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
35
Matthew Haddonb724efc2021-08-25 16:40:29 +010036
Eric Kunzee5e26762020-10-13 16:11:07 -070037class TosaTestGen:
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000038 # This currently matches the 8K level defined in the specification.
Jeremy Johnsonb2099702023-04-12 15:59:01 +010039 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010040 TOSA_8K_LEVEL_MAX_KERNEL = 8192
41 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010042
Jeremy Johnson1271c442023-09-05 11:39:26 +010043 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000044 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010045 TOSA_MI_DOT_PRODUCT_MIN = 1000
46
Eric Kunzee5e26762020-10-13 16:11:07 -070047 def __init__(self, args):
48 self.args = args
49 self.basePath = args.output_dir
50 self.random_seed = args.random_seed
51 self.ser = None
Eric Kunzee5e26762020-10-13 16:11:07 -070052 self.createDynamicOpLists()
53 self.initOpListDefaults()
54 self.quantGen = TosaQuantGen()
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010055 self.global_rng = None
Eric Kunzee5e26762020-10-13 16:11:07 -070056 # Force makeShape to do a specific starting shape
57 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010058 # JSON schema validation
59 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010060 # Data generator library is sometimes needed for compliance set up
61 # even if we are generating the data later (lazy_data_generation)
62 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070063
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010064 # Work out floating point range
65 def convertFPRange(rangeFP, maxFP):
66 # Converts program arguments of max/-max to FP max
67 vals = []
68 for v in rangeFP:
69 if v == "max":
70 v = maxFP
71 elif v == "-max":
72 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000073 elif v < 0:
74 # Trim to minimum data type value
75 v = max(v, -maxFP)
76 elif v > 0:
77 # Trim to maximum data type value
78 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010079 vals.append(v)
80 return tuple(sorted(vals))
81
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010082 self.random_dtype_range = {
83 DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
84 }
Won Jeon2c34b462024-02-06 18:37:00 +000085 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010086 self.random_dtype_range[dtype] = convertFPRange(
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010087 args.tensor_fp_value_range,
88 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
89 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010090 self.resetGlobalRNG()
91
92 def resetGlobalRNG(self):
93 self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010094
Eric Kunzee5e26762020-10-13 16:11:07 -070095 def createSerializer(self, opName, testPath):
96 self.testPath = os.path.join(opName, testPath)
97
98 fullPath = os.path.join(self.basePath, self.testPath)
99 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100100 # Embed const data in the flatbuffer
101 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +0100102 if self.args.lazy_data_gen:
103 # Lazy data generation - so make constants files
104 constMode = ts.ConstMode.INPUTS
105 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100106 constMode = ts.ConstMode.EMBED_DUMP
107 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700108
109 def getSerializer(self):
110 return self.ser
111
evacha01ad8e1e22024-03-19 12:42:17 +0000112 def serialize(self, testName, metaData=None, tags=None):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100113 path = Path(self.basePath) / self.testPath
114
115 # Write out TOSA flatbuffer binary
116 path_fb = path / f"{testName}.tosa"
117 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700118 fd.write(self.ser.serialize())
119
Jeremy Johnson1271c442023-09-05 11:39:26 +0100120 # Get JSON descriptor from serializer
121 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
122
123 if metaData:
124 # Add extra meta data to desc.json
125 desc["meta"] = metaData
126
evacha01ad8e1e22024-03-19 12:42:17 +0000127 if tags:
128 desc["tag"] = tags
129
Jeremy Johnson1271c442023-09-05 11:39:26 +0100130 # Validate desc.json before we output it
131 self.descSchemaValidator.validate_config(desc)
132
133 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100134 if "data_gen" in metaData:
135 if self.args.lazy_data_gen:
136 # Output datagen meta data as CPP data
137 path_md = path / f"{testName}_meta_data_gen.cpp"
138 with path_md.open("w") as fd:
139 fd.write(TOSA_AUTOGENERATED_HEADER)
140 fd.write("// Test meta data for data generation setup\n\n")
141 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
142 json.dump(metaData["data_gen"], fd)
143 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100144 if "compliance" in metaData:
145 # Output datagen meta data as CPP data
146 path_md = path / f"{testName}_meta_compliance.cpp"
147 with path_md.open("w") as fd:
148 fd.write(TOSA_AUTOGENERATED_HEADER)
149 fd.write("// Test meta data for compliance validation\n\n")
150 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
151 json.dump(metaData["compliance"], fd)
152 fd.write(')";\n\n')
153
154 # Write desc.json
155 path_desc = path / "desc.json"
156 with path_desc.open("w") as fd:
157 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100159 def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 placeholders = []
161
Kevin Cheng989cb052021-04-28 16:29:44 -0700162 assert len(shape_list) == len(dtype_list)
163
Jeremy Johnson1271c442023-09-05 11:39:26 +0100164 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700165 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100166 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100167 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700168 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700169
170 return placeholders
171
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100172 def buildConstTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 consts = []
174
Kevin Cheng989cb052021-04-28 16:29:44 -0700175 assert len(shape_list) == len(dtype_list)
176
Jeremy Johnson1271c442023-09-05 11:39:26 +0100177 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700178 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100179 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100180 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700181 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700182
183 return consts
184
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100185 def makeShape(self, rng, rank):
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 if self.targetted_shape:
187 return np.int32(self.targetted_shape)
Jeremy Johnson18a379d2024-03-28 15:53:21 +0000188 else:
189 return np.int32(
190 rng.integers(
191 low=self.args.tensor_shape_range[0],
192 high=self.args.tensor_shape_range[1],
193 size=rank,
194 )
Kevin Cheng550ccc52021-03-03 11:21:43 -0800195 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700196
197 def setTargetShape(self, shape):
198 self.targetted_shape = shape
199
Eric Kunzee5e26762020-10-13 16:11:07 -0700200 def shapeStr(self, shape):
Jeremy Johnson18a379d2024-03-28 15:53:21 +0000201 assert shape is not None
202 if len(shape) > 0:
203 # Rank > 0
204 return "x".join([str(d) for d in shape])
205 else:
206 # Rank 0
207 return "0"
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100209 def typeStr(self, dtype):
210 if isinstance(dtype, list) or isinstance(dtype, tuple):
211 assert len(dtype) >= 2
212 strs = [self.typeStr(t) for t in dtype]
213 # Limit types to the first 2 as the 3rd is the accumulator
214 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100216 if dtype in gtu.DTYPE_ATTRIBUTES:
217 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100219 raise Exception(
220 "Unknown dtype, cannot convert to string: {}".format(dtype)
221 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700222
Luke Hutton57287132023-02-06 14:54:18 +0000223 def constrictBatchSize(self, shape):
224 # Limit the batch size unless an explicit target shape set
225 if self.args.max_batch_size and not self.args.target_shapes:
226 shape[0] = min(shape[0], self.args.max_batch_size)
227 return shape
228
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100229 def makeDimension(self, rng):
230 return rng.randInt(
James Ward30124a82023-02-02 14:56:33 +0000231 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
232 )
233
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100234 def tensorComplianceMetaData(
235 self, op, inputType, argsDict, outputTensor, errorName
236 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000237 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
238 UNSUPPORTED_NON_FP32_INPUT_OPS = (
239 Op.MATMUL,
240 Op.CONV2D,
241 Op.FULLY_CONNECTED,
242 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000243 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000244 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000245 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100246 if (
247 errorName
248 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000249 or (
250 not gtu.dtypeIsSupportedByCompliance(inputType)
251 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
252 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100253 ):
254 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100255 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100256
Jeremy Johnson1271c442023-09-05 11:39:26 +0100257 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100258 compliance_tens = {
259 "mode": None,
260 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
261 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
262 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
264 mode = gtu.ComplianceMode.DOT_PRODUCT
265 compliance_tens["dot_product_info"] = {
266 "s": argsDict["s"],
Suraj Sudhirb5fcfc02024-04-16 16:14:36 -0700267 "ks": (
268 int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
269 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100270 }
evacha014a205112024-03-08 16:39:24 +0000271 elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100272 mode = gtu.ComplianceMode.FP_SPECIAL
273 elif "compliance" in op and "ulp" in op["compliance"]:
274 mode = gtu.ComplianceMode.ULP
275 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000276 elif "compliance" in op and "relative" in op["compliance"]:
277 mode = gtu.ComplianceMode.RELATIVE
278 compliance_tens["relative_info"] = {
279 "max": argsDict["max_abs_value"],
280 "scale": op["compliance"]["relative"],
281 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100282 elif op["op"] == Op.REDUCE_PRODUCT:
283 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000284 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000285 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000286 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000287 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
288 compliance_tens["abs_error_info"] = {
289 "lower_bound": op["compliance"]["abs_error_lower_bound"]
290 }
Jerry Ge51bd4f52024-02-20 11:21:19 -0800291 elif op["op"] in (Op.SIN, Op.COS):
292 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnson1eb14552024-04-11 16:21:54 +0100293 if "compliance" in op:
294 normal_divisor = op["compliance"].get("abs_error_normal_divisor", 1)
295 bound_addition = op["compliance"].get("abs_error_bound_addition", 0)
296 else:
297 normal_divisor = 1
298 bound_addition = 0
299
300 compliance_tens["abs_error_info"] = {
301 "normal_divisor": normal_divisor,
302 "bound_as_magnitude": True,
303 "bound_addition": bound_addition,
304 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100305 else:
306 mode = gtu.ComplianceMode.EXACT
307 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
308
309 return compliance_tens
310
311 # Build Op functions
312 # Create the output tensor (calling OutputShaper as needed)
313 # Do final tweaks to attributes (if necessary for errorIf)
314 # Add Op into graph
315 # Return resulting tensor information or BuildInfo
316
317 class BuildInfo:
318 """Enhanced build information containing result tensor and associated compliance dict."""
319
320 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000321 if isinstance(resultTensor, list):
322 assert complianceDict is None or isinstance(complianceDict, list)
323 self.resultTensorList = resultTensor
324 self.complianceDictList = complianceDict
325 else:
326 self.resultTensorList = [resultTensor]
327 if complianceDict is None:
328 self.complianceDictList = None
329 else:
330 self.complianceDictList = [complianceDict]
331
332 def getComplianceInfo(self):
333 if self.complianceDictList is None:
334 return None
335 else:
336 tens_dict = {}
337 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
338 if comp is not None:
339 tens_dict[tens.name] = comp
340
341 if tens_dict:
342 # Have some compliance data, so return the info
343 compliance = {
344 "version": "0.1",
345 "tensors": tens_dict,
346 }
347 else:
348 compliance = None
349 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700350
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000351 def build_unary(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100352 self,
353 rng,
354 op,
355 inputs,
356 args_dict,
357 validator_fcns=None,
358 error_name=None,
359 qinfo=None,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000360 ):
361 assert len(inputs) == 1
362 a = inputs[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100363 result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100364
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000365 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100366
367 # Ensure new output type has correct qinfo
368 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000369 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000370 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100371 TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype),
372 TosaQuantGen.getZeroPoint(
373 rng, self.args.zeropoint, result_tensor.dtype
374 ),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000375 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100376
377 # Invalidate Input/Output list for error if checks.
378 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000379 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100380 pCount, cCount = op["operands"]
381 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000382 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100383 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100385
Les Bell729b0352021-11-24 10:28:21 +0000386 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100387 self.ser,
388 validator_fcns,
389 error_name,
390 op=op,
391 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000392 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000393 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000394 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100395 input_list=input_list,
396 output_list=output_list,
397 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000398 ):
399 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100400
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000401 attr = None
402 if op["op"] == Op.NEGATE:
403 attr = ts.TosaSerializerAttribute()
404 attr.NegateAttribute(qinfo[0], qinfo[1])
405
406 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000407
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000408 compliance = self.tensorComplianceMetaData(
409 op, a.dtype, args_dict, result_tensor, error_name
410 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000411 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700412
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000413 def build_binary_broadcast(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100414 self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000415 ):
416 assert len(inputs) == 2
417 a, b = inputs
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100418 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100419
420 # Invalidate Input/Output list for error if checks.
421 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000422 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100423 pCount, cCount = op["operands"]
424 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000425 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100426 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000427 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100428
Les Bell729b0352021-11-24 10:28:21 +0000429 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100430 self.ser,
431 validator_fcns,
432 error_name,
433 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000434 input1=a,
435 input2=b,
436 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000437 output_dtype=result_tensor.dtype,
438 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100439 input_list=input_list,
440 output_list=output_list,
441 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000442 ):
443 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100444
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000445 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000446
Jeremy Johnson9a758382023-11-07 16:27:35 +0000447 compliance = self.tensorComplianceMetaData(
448 op, a.dtype, args_dict, result_tensor, error_name
449 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000450
451 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700452
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000453 def build_arithmetic_right_shift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100454 self,
455 rng,
456 op,
457 inputs,
458 args_dict,
459 validator_fcns=None,
460 error_name=None,
461 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000462 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000463 assert len(inputs) == 2
464 a, b = inputs
465 round = args_dict["round"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100466 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100467
468 # Invalidate Input/Output list for error if checks.
469 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000470 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100471 pCount, cCount = op["operands"]
472 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000473 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100474 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000475 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100476
Les Bell729b0352021-11-24 10:28:21 +0000477 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478 self.ser,
479 validator_fcns,
480 error_name,
481 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000482 input1=a,
483 input2=b,
484 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000485 output_dtype=result_tensor.dtype,
486 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100487 input_list=input_list,
488 output_list=output_list,
489 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000490 ):
491 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800492
493 attr = ts.TosaSerializerAttribute()
494 attr.ArithmeticRightShiftAttribute(round)
495
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000496 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000497
498 compliance = self.tensorComplianceMetaData(
499 op, a.dtype, args_dict, result_tensor, error_name
500 )
501
502 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800503
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100504 def build_mul(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100505 self,
506 rng,
507 op,
508 inputs,
509 args_dict,
510 validator_fcns=None,
511 error_name=None,
512 qinfo=None,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100513 ):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000514 # Note that mul is binary operator but it has a shift value tensor
515 assert len(inputs) == 3
516 a, b, s = inputs
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100517
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100518 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700519
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100520 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100521 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100522 result_tensor.setDtype(DType.INT32)
523
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100524 if error_name == ErrorIf.WrongOutputType:
525 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100526 outputDType = rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100527 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100528
529 # Invalidate Input/Output list for error if checks.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000530 input_list = [a.name, b.name, s.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100531 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100532 pCount, cCount = op["operands"]
533 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000534 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100535 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000536 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100537
Les Bell729b0352021-11-24 10:28:21 +0000538 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100539 self.ser,
540 validator_fcns,
541 error_name,
542 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000543 input1=a,
544 input2=b,
545 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100546 output_dtype=result_tensor.dtype,
547 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100548 input_list=input_list,
549 output_list=output_list,
550 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000551 ):
552 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Jeremy Johnson0a042992024-02-28 13:20:05 +0000554 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100555
556 compliance = self.tensorComplianceMetaData(
557 op, a.dtype, args_dict, result_tensor, error_name
558 )
559
560 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700561
Jeremy Johnson587cc842024-02-08 11:45:44 +0000562 def build_table(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100563 self,
564 rng,
565 op,
566 inputs,
567 args_dict,
568 validator_fcns=None,
569 error_name=None,
570 qinfo=None,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000571 ):
572 assert len(inputs) == 1
573 a = inputs[0]
574 table = args_dict["table"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100575 result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700576
Kevin Chengfe392ce2021-10-18 21:51:55 +0000577 attr = ts.TosaSerializerAttribute()
578 attr.TableAttribute(table)
579
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100580 # Invalidate Input/Output list for error if checks.
581 input_list = [a.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000582 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100583 pCount, cCount = op["operands"]
584 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000585 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100586 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000587 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100588
Les Bell729b0352021-11-24 10:28:21 +0000589 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100590 self.ser,
591 validator_fcns,
592 error_name,
593 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000594 input_shape=a.shape,
595 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000596 output_dtype=result_tensor.dtype,
597 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100598 input_list=input_list,
599 output_list=output_list,
600 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000601 ):
602 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100603
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000604 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Jeremy Johnson587cc842024-02-08 11:45:44 +0000606 compliance = self.tensorComplianceMetaData(
607 op, a.dtype, args_dict, result_tensor, error_name
608 )
609
610 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700611
def build_select(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a SELECT operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.selectOp and to
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly three tensors, ordered (cond, a, b).
        args_dict: Test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 3
    cond, lhs, rhs = inputs

    out_tensor = OutputShaper.selectOp(self.ser, rng, cond, lhs, rhs, error_name)

    # Total operand count (placeholders + consts) used by the ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    operand_total = n_placeholders + n_consts

    # Invalidate the input/output name lists for ERROR_IF checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [cond.name, lhs.name, rhs.name],
        [out_tensor.name],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=lhs,
        input3=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700664
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a binary comparison operator to the serializer.

    Args:
        rng: Random generator, forwarded to OutputShaper.binaryComparisonOp
            and to TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly two tensors, ordered (a, b).
        args_dict: Test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(self.ser, rng, lhs, rhs, error_name)

    # Total operand count (placeholders + consts) used by the ERROR_IF checks.
    n_placeholders, n_consts = op["operands"]
    operand_total = n_placeholders + n_consts

    # Invalidate the input/output name lists for ERROR_IF checks.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [lhs.name, rhs.name],
        [out_tensor.name],
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700717
def build_argmax(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an ARGMAX operator to the serializer and return its build info.

    ``validator_fcns`` and ``error_name`` default to None, matching the
    other build_* methods; positional callers are unaffected.

    Args:
        rng: Random generator, forwarded to OutputShaper.argmaxOp and to
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly one tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` selects the axis.
            Also forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700761
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2D pooling operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.pool2dOp,
            TosaQuantGen.getZeroPoint and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly one tensor (the input feature map).
        args_dict: Test arguments; reads "stride", "pad", "kernel" and the
            optional "acc_type" (DType.UNKNOWN when absent). Also forwarded
            to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to PoolAttribute as qinfo[0] and
            qinfo[1]; [0, 0] is used when None.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 1
    # Named ifm rather than `input` to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100839
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CONV2D operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.conv2dOp,
            TosaQuantGen.getZeroPoint and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``sum(op["operands"])`` is the operand count for ERROR_IF checks.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (must
            have 4 entries) and "dilation". Also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to ConvAttribute as qinfo[0] and
            qinfo[1]; [0, 0] is used when None (previously this crashed).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 3
    # `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Honour the qinfo=None default instead of failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700927
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CONV3D operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.conv3dOp,
            TosaQuantGen.getZeroPoint and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``sum(op["operands"])`` is the operand count for ERROR_IF checks.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Test arguments; reads "acc_type", "stride", "pad" (must
            have 6 entries) and "dilation". Also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to ConvAttribute as qinfo[0] and
            qinfo[1]; [0, 0] is used when None (previously this crashed).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 3
    # `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Honour the qinfo=None default instead of failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001015
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TRANSPOSE_CONV2D operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.transposeConv2DOp,
            TosaQuantGen.getZeroPoint and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``sum(op["operands"])`` is the operand count for ERROR_IF checks.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Test arguments; reads "acc_type", "stride" and "pad" (the
            out_pad, must have 4 entries). Also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to TransposeConvAttribute as qinfo[0]
            and qinfo[1]; [0, 0] is used when None (previously this crashed).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 3
    # `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, weight, accum_dtype, strides, out_pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Honour the qinfo=None default instead of failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001093
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DEPTHWISE_CONV2D operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.depthwiseConv2dOp,
            TosaQuantGen.getZeroPoint and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``sum(op["operands"])`` is the operand count for ERROR_IF checks.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Test arguments; reads "acc_type", "stride", "pad" and
            "dilation". Also forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to ConvAttribute as qinfo[0] and
            qinfo[1]; [0, 0] is used when None (previously this crashed).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 3
    # `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Honour the qinfo=None default instead of failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001180
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a FULLY_CONNECTED operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.fullyConnectedOp and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly three tensors, ordered (ifm, weight, bias).
        args_dict: Test arguments; reads "acc_type". Also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to FullyConnectedAttribute as qinfo[0]
            and qinfo[1]; [0, 0] is used when None (previously this crashed).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 3
    # `weight` rather than `filter` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Honour the qinfo=None default instead of failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001237
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a MATMUL operator to the serializer and return its build info.

    Args:
        rng: Random generator, forwarded to OutputShaper.matmulOp and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly two tensors, ordered (a, b).
        args_dict: Test arguments; reads "acc_type". Also forwarded to
            tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Two zero points passed to MatMulAttribute as qinfo[0] and
            qinfo[1]; [0, 0] is used when None (previously this crashed).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Honour the qinfo=None default instead of failing on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001294
def build_reduce(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a REDUCE_* operator to the serializer and return its build info.

    ``validator_fcns`` defaults to None, matching the other build_* methods;
    positional callers are unaffected.

    Args:
        rng: Random generator, forwarded to OutputShaper.reduceOp and
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator description; ``op["op"]`` is passed to addOperator and
            ``op["operands"]`` is unpacked as a (placeholder, const) count pair.
        inputs: Exactly one tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` selects the axis.
            For error-free REDUCE_PRODUCT tests this method also writes
            ``args_dict["n"] = a.shape[axis]`` (mutating the caller's dict)
            before passing it to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validators, forwarded to
            TosaErrorValidator.evValidateErrorIfs.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when evValidateErrorIfs returns a falsy value (no operator is added).
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001343
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn. Normally the smaller one is
    the min bound and the larger the max bound. For the ``MaxSmallerMin``
    ERROR_IF test the values are redrawn until distinct and then swapped, so
    that max < min.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns a
    falsy value; otherwise a ``TosaTestGen.BuildInfo`` with the result tensor
    and compliance metadata.
    """
    assert len(inputs) == 1
    (src,) = inputs

    out_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    def draw_pair():
        # Two independent random scalars of the input dtype, drawn in order.
        return [rng.randNumberDType(src.dtype), rng.randNumberDType(src.dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot trigger the error, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    def as_padded_bytes(value):
        # Serialize one scalar, then zero-pad the byte vector to a multiple of 8.
        raw = ts.TosaSerializer.convertDataToUint8Vec(src.dtype, [value])
        while len(raw) % 8:
            raw.append(0)
        return raw

    clamp_attr = ts.TosaSerializerAttribute()
    min_bytes = as_padded_bytes(min_val)
    max_bytes = as_padded_bytes(max_val)
    clamp_attr.ClampAttribute(self.ser.builder, min_bytes, max_bytes)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001417
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an attribute-less, single-input elementwise operator.

    The output shape/type comes from ``OutputShaper.unaryOp``. Returns ``None``
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value;
    otherwise a ``TosaTestGen.BuildInfo`` with the result tensor and compliance
    metadata.
    """
    assert len(inputs) == 1
    (src,) = inputs

    out_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001465
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONCAT or CONCAT_SHAPE operator over all tensors in ``inputs``.

    CONCAT_SHAPE always concatenates along axis 0 and carries no attribute.
    CONCAT reads the axis from ``args_dict["axis"]`` and serializes it as an
    axis attribute. Returns ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value; otherwise
    a ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata
    (computed against the first input's dtype).
    """
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [tens.name for tens in inputs], [out_tensor.name]
    )

    lead = inputs[0]
    error_checks = dict(
        op=op,
        axis=axis,
        input_shape=lead.shape,
        output_shape=out_tensor.shape,
        input_dtype=lead.dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    if op["op"] != Op.CONCAT:
        assert op["op"] == Op.CONCAT_SHAPE
        concat_attr = None
    else:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, lead.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001531
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator.

    ``inputs`` is (data tensor, padding operand). The padding amounts used for
    output shaping and validation come from ``args_dict["pad"]``. The pad
    constant is ``args_dict["pad_const_fp"]`` for float dtypes and
    ``args_dict["pad_const_int"]`` otherwise; it is serialized to bytes,
    zero-padded to a multiple of 8, and stored in a pad attribute.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns a
    falsy value; otherwise a ``TosaTestGen.BuildInfo`` with the result tensor
    and compliance metadata.
    """
    assert len(inputs) == 2
    src, pad_input = inputs
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, rng, src, padding, error_name)

    # Pick the pad constant that matches the element type of the data tensor.
    pad_const = pad_const_float if gtu.dtypeIsFloat(src.dtype) else pad_const_int
    pad_const_bytes = ts.TosaSerializer.convertDataToUint8Vec(src.dtype, [pad_const])
    # The serialized constant is zero-padded to an 8-byte boundary.
    while len(pad_const_bytes) % 8:
        pad_const_bytes.append(0)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_const_bytes)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, pad_input.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001603
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator querying ``args_dict["axis"]`` of a single input.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns a
    falsy value; otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and no compliance metadata.
    """
    assert len(inputs) == 1
    (src,) = inputs
    dim_axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, rng, src, dim_axis, error_name)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001650
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESHAPE operator.

    ``inputs`` is (data tensor, shape operand). The target shape used to
    compute the output tensor comes from ``args_dict["new_shape"]``. No
    attribute is attached. Returns ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value; otherwise
    a ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 2
    src, shape_operand = inputs
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(self.ser, rng, src, new_shape, error_name)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, shape_operand.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001699
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a REVERSE operator along ``args_dict["axis"]`` of a single input.

    The output shape/type comes from ``OutputShaper.unaryOp``. Returns ``None``
    when ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value;
    otherwise a ``TosaTestGen.BuildInfo`` holding the result tensor and no
    compliance metadata.
    """
    assert len(inputs) == 1
    (src,) = inputs
    rev_axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001746
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    The permutation is stored in a transpose attribute. Returns ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value; otherwise
    a ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 1
    (src,) = inputs
    permutation = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(self.ser, rng, src, permutation, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(permutation)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        perms=permutation,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001800
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SLICE operator.

    ``inputs`` is (data tensor, start operand, size operand). The matching
    constant values ``args_dict["start"]`` / ``args_dict["size"]`` are used to
    compute the output tensor and for ERROR_IF validation. No attribute is
    attached. Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs``
    returns a falsy value; otherwise a ``TosaTestGen.BuildInfo`` with the
    result tensor and compliance metadata.
    """
    assert len(inputs) == 3
    src, start_operand, size_operand = inputs
    start_vals = args_dict["start"]
    size_vals = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, rng, src, start_vals, size_vals, error_name
    )

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    operand_names = [src.name, start_operand.name, size_operand.name]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, operand_names, [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        start=start_vals,
        size=size_vals,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001855
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TILE operator.

    ``inputs`` is (data tensor, multiples operand). The constant multiples
    used to compute the output tensor come from ``args_dict["multiples"]``. No
    attribute is attached. Returns ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value; otherwise
    a ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 2
    src, multiples_operand = inputs
    multiples_vals = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(self.ser, rng, src, multiples_vals, error_name)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, multiples_operand.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
        input1=src,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001907
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a GATHER operator from (values, indices) inputs.

    No attribute is attached. ERROR_IF validation and compliance metadata are
    computed against the ``values`` tensor. Returns ``None`` when
    ``TosaErrorValidator.evValidateErrorIfs`` returns a falsy value; otherwise
    a ``TosaTestGen.BuildInfo`` with the result tensor and compliance metadata.
    """
    assert len(inputs) == 2
    values, indices = inputs

    out_tensor = OutputShaper.gatherOp(self.ser, rng, values, indices, error_name)

    # For ERROR_IF tests the operand/result name lists may be invalidated.
    param_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [values.name, indices.name], [out_tensor.name]
    )

    error_checks = dict(
        op=op,
        input_shape=values.shape,
        output_shape=out_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=param_count + const_count,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001957
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test.

    ``inputs`` must hold exactly three tensors, in order: values_in, indices
    and input. The result tensor is created by ``OutputShaper.scatterOp``.

    Returns ``None`` when the ERROR_IF validators reject the configuration,
    otherwise a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance metadata (computed against the values_in dtype).
    """
    assert len(inputs) == 3
    values_in_tens, indices_tens, input_tens = inputs

    out_tens = OutputShaper.scatterOp(
        self.ser, rng, values_in_tens, indices_tens, input_tens, error_name
    )

    # ERROR_IF tests may deliberately corrupt the operand/result name lists
    primary_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_in_tens.name, indices_tens.name, input_tens.name],
        [out_tens.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in_tens.shape,
        output_shape=out_tens.shape,
        input_dtype=values_in_tens.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=primary_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tens,
        self.tensorComplianceMetaData(
            op, values_in_tens.dtype, args_dict, out_tens, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08002006
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` must contain four tensors: the input tensor followed by the
    scale, offset and border tensors. ``args_dict`` supplies ``mode``,
    ``scale``, ``offset``, ``border`` and ``output_dtype``; these values are
    used to shape the result (via ``OutputShaper.resizeOp``) and for the
    ERROR_IF validation. The operator itself receives scale/offset/border as
    input tensors, so the serialized ``ResizeAttribute`` carries empty lists
    for them.

    ``validator_fcns`` now defaults to ``None`` like every other builder;
    existing callers that pass it are unaffected.

    Returns ``None`` when the ERROR_IF validators reject the test, otherwise
    a ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 4
    # Avoid shadowing the ``input`` builtin
    input_tens = inputs[0]
    scale_input = inputs[1]
    offset_input = inputs[2]
    border_input = inputs[3]

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input_tens,
        mode,
        scale,
        offset,
        border,
        input_tens.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tens.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tens.dtype,
        output_dtype=output_dtype,
        input_shape=input_tens.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # write empty scale/offset/border into ResizeAttribute (they are passed
    # to the operator as input tensors instead)
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tens.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002086
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST test.

    The single tensor in ``inputs`` is registered directly as a graph output
    tensor (no operator is added here). Its compliance metadata is computed
    and returned together with it in a ``TosaTestGen.BuildInfo``.
    """
    assert len(inputs) == 1
    const_tens = inputs[0]

    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002106
2107 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST (type conversion) operator test.

    ``inputs`` holds the single tensor to convert. The target dtype is read
    from ``args_dict["out_type"]``, and the result tensor is created by
    ``OutputShaper.typeConversionOp``.

    Returns ``None`` if the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance data
    (computed against the source tensor's dtype).
    """
    assert len(inputs) == 1
    src_tens = inputs[0]

    out_tens = OutputShaper.typeConversionOp(
        self.ser, rng, src_tens, args_dict["out_type"], error_name
    )

    # ERROR_IF tests may deliberately corrupt the operand/result name lists
    primary_count, const_count = op["operands"]
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src_tens.name], [out_tens.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src_tens.shape,
        output_shape=out_tens.shape,
        input_dtype=src_tens.dtype,
        output_dtype=out_tens.dtype,
        result_tensors=[out_tens],
        input_list=in_names,
        output_list=out_names,
        num_operands=primary_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tens,
        self.tensorComplianceMetaData(
            op, src_tens.dtype, args_dict, out_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002158
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``inputs`` holds three tensors, in order: the value tensor to rescale,
    the multiplier tensor and the shift tensor. ``args_dict`` supplies
    ``output_dtype``, ``scale`` (the scale32 flag), ``double_round``,
    ``per_channel`` and the ``shift`` / ``multiplier`` arrays.

    Input and output zero points are chosen at random from a range that
    depends on the dtype. For the zero-point ERROR_IF tests they are forced
    to be non-zero.

    When ``scale32`` is set and no error test is being generated, the input
    values stored in the tensor's placeholder file are clamped so that
    ``value - input_zp`` lies within
    ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]`` for each channel's
    shift. The file is rewritten only if a value actually changed.

    Returns ``None`` if the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # Number of scale channels: the innermost dimension when per-channel
    nc = val.shape[-1] if per_channel else 1

    def _choose_zero_point(dtype, zp_error_names):
        """Return (zero_point, is_unsigned) for one side of the rescale."""
        if dtype == DType.INT8:
            return rng.randInt(-128, 128), False
        if dtype == DType.UINT8:
            return rng.randInt(0, 256), True
        if error_name in zp_error_names:
            # ERROR_IF test: the zero point must be non-zero
            zp = rng.randInt(-128, 128)
            if zp == 0:
                zp = zp + rng.integers(1, 10)
            return zp, False
        if dtype == DType.UINT16:
            # Must come after the U16*ZeroPointNotValid error check above
            return rng.choice([0, 32768]), True
        return 0, False

    # Input then output: keeps the original order of random draws
    input_zp, input_unsigned = _choose_zero_point(
        val.dtype,
        (ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid),
    )
    output_zp, output_unsigned = _choose_zero_point(
        out_dtype,
        (ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid),
    )

    # Per-channel bounds from the apply_scale_32 precondition
    min_shift_value_arr = np.zeros(shape=[nc], dtype=np.int64)
    max_shift_value_arr = np.zeros(shape=[nc], dtype=np.int64)
    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        "build_rescale: multiplier=%s shift=%s inzp=%s outzp=%s",
        multiplier_arr,
        shift_arr,
        input_zp,
        output_zp,
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype. Compare the
        # element values themselves; ndarray.all() would collapse them into
        # a single bool and make this check vacuous.
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        ), f"build_rescale: adjusted values do not fit in {values.dtype}"

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    # The validators receive the chosen zero points as the quantization info
    zp_info = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=zp_info,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002325
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Create the constant condition tensor used by the COND_IF builders.

    Normally this is a rank-0 BOOL constant holding ``cond``. For the
    CondIfCondNotMatchingBool error test a wrong dtype is substituted. For
    CondIfCondShapeNotSizeOne the shape is randomly ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2342
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    Of the two tensors in ``inputs``, only the shape of the first is used; it
    becomes the output shape. ``args_dict["condition"]`` is the condition
    value. Each block produces an INT32 constant of that shape filled from
    ``rng.integers(0, 256, ...)``. For the CondIfOutputList{Then,Else}Graph-
    Mismatch error tests, the affected block instead outputs a constant
    whose every dimension has been perturbed.

    Returns ``None`` if the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` for the COND_IF result tensor.
    """
    assert len(inputs) == 2
    shape_src, _ = inputs

    cond = args_dict["condition"]
    dtype = DType.INT32

    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = shape_src.shape

    # ERROR_IF tests: build a constant whose shape differs from out_shape
    if error_name in (
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ):
        incorrect_shape = [
            dim + (rng.choice([-3, -2, 2, 3]) if dim > 3 else rng.choice([1, 2, 4]))
            for dim in out_shape
        ]
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    # The COND_IF result tensor takes the (correct) output shape
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Emit THEN then ELSE, each producing a single constant output
    for block_name, block_arr, mismatch_error in (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_tens = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002428
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to (a, b).

    FP32/BF16/FP16/INT32 inputs use ``add`` in the THEN block and ``sub`` in
    the ELSE block. INT8/INT16 inputs use ``logical_right_shift`` and
    ``logical_left_shift``. ``args_dict["condition"]`` is the condition
    value, and it also selects which block's op the compliance metadata is
    generated for.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch error tests, a
    copy of ``a`` with every dimension perturbed replaces either the block's
    first input or the shape of its output.

    Returns ``None`` if the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` for the COND_IF result tensor.

    Raises:
        AssertionError: if ``a.dtype`` has no then/else op pairing.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # A copy of ``a`` whose every dimension is perturbed
        incorrect_shape = [dim + rng.choice([-3, -2, 2, 3]) for dim in a.shape]
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        # Explicit raise rather than ``assert False``: an assert is stripped
        # under ``python -O``, which would leave then_op/else_op unbound.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    # Per block: (name, op, input-mismatch error, output-mismatch error)
    block_specs = (
        (
            then_block,
            then_op,
            ErrorIf.CondIfInputListThenGraphMismatch,
            ErrorIf.CondIfOutputListThenGraphMismatch,
        ),
        (
            else_block,
            else_op,
            ErrorIf.CondIfInputListElseGraphMismatch,
            ErrorIf.CondIfOutputListElseGraphMismatch,
        ),
    )
    for block, block_op, input_mismatch, output_mismatch in block_specs:
        self.ser.addBasicBlock(block)
        block_input = incorrect_block_input if error_name == input_mismatch else a
        block_out_shape = (
            incorrect_block_input.shape
            if error_name == output_mismatch
            else a.shape
        )
        self.ser.addInputTensor(block_input)
        self.ser.addInputTensor(b)
        tens = self.ser.addOutput(block_out_shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2533
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test that counts an iteration tensor down.

    The loop carries three tensors: an INT32 iteration counter, the input
    ``a`` and an accumulator (initialised to zeros).  The COND block
    evaluates ``iter > 0``.  The BODY block adds ``a`` into the accumulator
    and subtracts one from the counter.  For ERROR_IF tests, selected
    tensor shapes/dtypes are corrupted according to ``error_name``.

    Returns:
        None if the ERROR_IF validators reject the graph, otherwise a
        TosaTestGen.BuildInfo for the final accumulator tensor.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iteration_count = args_dict["iterations"]

    def _perturb_dims(tensor):
        # Shift every dimension (in place) by a random non-zero amount
        dims = tensor.shape
        for idx in range(len(dims)):
            dims[idx] += rng.choice([-3, -2, 2, 3])
        return tensor

    # Loop-carried iteration counter (rank 0)
    iter_tens = self.ser.addPlaceholder(
        [], DType.INT32, [np.int32(iteration_count)]
    )

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator starts as zeros with the shape/dtype of ``a``
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = _perturb_dims(deepcopy(acc)).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # The WHILE_LOOP operator itself
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_mismatch_errors = (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    )
    if error_name in graph_mismatch_errors:
        bad_iter = _perturb_dims(deepcopy(iter_tens))
        if not bad_iter.shape:
            # A rank 0 counter has nothing to shift, so add a dimension
            bad_iter.shape.append(rng.choice([-3, -2, 2, 3]))
        bad_acc = _perturb_dims(deepcopy(acc))

    # COND block: inputs (iter, a, acc), output cond_tens = iter > 0
    self.ser.addBasicBlock(cond_block)

    corrupt_cond_inputs = error_name == ErrorIf.InputListCondGraphMismatch
    self.ser.addInputTensor(bad_iter if corrupt_cond_inputs else iter_tens)
    self.ser.addInputTensor(a)
    self.ser.addInputTensor(bad_acc if corrupt_cond_inputs else acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL

    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block: inputs (iter, a, acc), outputs (iter - 1, a, acc + a)
    # Local intermediate tensors must be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    corrupt_body_inputs = error_name == ErrorIf.InputListBodyGraphInputMismatch
    self.ser.addInputTensor(bad_iter if corrupt_body_inputs else iter_tens)
    self.ser.addInputTensor(a)
    self.ser.addInputTensor(bad_acc if corrupt_body_inputs else acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_template, acc_template = bad_iter, bad_acc
    else:
        iter_template, acc_template = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(
        iter_template.shape, iter_template.dtype
    )
    acc_body_out = self.ser.addIntermediate(acc_template.shape, acc_template.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for body_output in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(body_output)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not is_valid:
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )
    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002671
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from two input tensors.

    ``args_dict["inverse"]`` is passed as the FFT attribute's inverse flag.

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo holding all result tensors and one compliance
        entry per result tensor.
    """
    assert len(inputs) == 2
    first_in, second_in = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, first_in, second_in, error_name)

    placeholders, consts = op["operands"]
    result_shapes = [tens.shape for tens in results]
    result_dtypes = [tens.dtype for tens in results]

    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [first_in.name, second_in.name],
        [tens.name for tens in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=first_in,
        input2=second_in,
        input_shape=first_in.shape,
        input_dtype=first_in.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound; until then the attribute is always False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(
            op, first_in.dtype, args_dict, tens, error_name
        )
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002736
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo holding all result tensors and one compliance
        entry per result tensor.
    """
    assert len(inputs) == 1
    src = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, rng, src, error_name)

    placeholders, consts = op["operands"]
    result_shapes = [tens.shape for tens in results]
    result_dtypes = [tens.dtype for tens in results]

    # Invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [tens.name for tens in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound; until then the attribute is always False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, src.dtype, args_dict, tens, error_name)
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002794
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a two-input shape operator test.

    Args:
        rng: random generator for this test; used by the output shaper and
            for the ERROR_IF input/output list invalidation.
        op: TOSA_OP_LIST entry for the operator.
        inputs: the two input tensors.
        args_dict: test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions for the operator.
        error_name: ERROR_IF being tested, or None for a positive test.
        qinfo: unused.

    Returns:
        None if the ERROR_IF validators reject the test, otherwise a
        TosaTestGen.BuildInfo for the result tensor.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # Pass the test's random generator (previously ``self`` was passed here
    # by mistake), matching the fft2d/rfft2d builders
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2847
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Return the shape, rank and dtype filters to generate tests with.

    Requested filters are clipped to what the operator allows: ranks must lie
    within ``op["rank"]`` and dtypes must appear in ``op["types"]``.
    Explicit shapes take precedence over ``rankFilter`` (the full operator
    rank range is then used).  Without shapes or a rank filter, a small
    default rank range is used, falling back to the operator ranks if none
    of the defaults are allowed.

    For negative tests, the validator's "param_reqs" may override any of the
    filters, and only the first two shapes are kept otherwise.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a negative test without a validator (or an unknown
        ``testType``).
    """
    rmin, rmax = op["rank"]

    if testType == "positive":
        # 0-3 inclusive keeps test sizes reasonably small
        default_ranks = range(0, 4)
    else:
        # Some errors do not work with rank 0, so use 1-3
        default_ranks = range(1, 4)

    if shapeFilter:
        # Shapes given - drop the rank filter and use all op ranks below
        rankFilter = None
        candidate_ranks = []
    elif rankFilter is None:
        # Nothing requested - keep the default behaviour bounded
        candidate_ranks = default_ranks
    else:
        candidate_ranks = rankFilter

    allowed_ranks = [rank for rank in candidate_ranks if rmin <= rank <= rmax]
    if shapeFilter or (not allowed_ranks and rankFilter is None):
        # Shapes specified, or no default rank suits the op: use op ranks
        allowed_ranks = range(rmin, rmax + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        allowed_dtypes = op_dtypes
    else:
        # A list entry (e.g. a dtype combination) matches on its first item
        allowed_dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    shapes = shapeFilter if shapeFilter else [None]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": allowed_ranks,
            "dtypeFilter": allowed_dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    req_rank, req_dtype, req_shape = reqs["rank"], reqs["dtype"], reqs["shape"]
    return {
        # Keep only two shapes when unconstrained to limit the test count
        "shapeFilter": shapes[:2] if req_shape is None else req_shape,
        "rankFilter": allowed_ranks if req_rank is None else req_rank,
        "dtypeFilter": allowed_dtypes if req_dtype is None else req_dtype,
    }
2938
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Create the list of tests to generate for a single operator.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: optional list of shapes to create tests for.  None (the
            default) means no shapes were requested, so the rank filter /
            default rank range applies.  A non-empty list makes
            create_filter_lists ignore ``rankFilter``.  (The default used to
            be ``[None]``, which is truthy and was therefore wrongly treated
            as "shapes specified".)
        rankFilter: optional list of ranks to restrict tests to.
        dtypeFilter: optional list of dtypes to restrict tests to.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples.

    Raises:
        Exception: if ``opName`` is unknown, or, for negative tests outside
            level8k mode, if some ERROR_IF types produced no tests.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, args)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
        num_error_types_created = 0
    else:
        error_if_validators = [None]
        num_error_types_created = None

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument list consists of (argStr, args) tuples: the
                    # string used in the test name and the build arguments
                    if agen_fcn:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng

                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

        if error_name is not None:
            # Check the last test is of the error we wanted
            if len(testList) == 0 or testList[-1][3] != error_name:
                if self.args.level8k:
                    logger.info(f"Missing {error_name} tests due to level8k mode")
                else:
                    logger.error(f"ERROR: Failed to create any {error_name} tests")
                    logger.debug(
                        "Last test created: {}".format(
                            testList[-1] if testList else None
                        )
                    )
            else:
                # Successfully created at least one ERROR_IF test
                num_error_types_created += 1

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList
    else:
        if num_error_types_created is not None and not self.args.level8k:
            remaining_error_types = (
                len(error_if_validators) - num_error_types_created
            )
            if remaining_error_types:
                raise Exception(
                    f"Failed to create {remaining_error_types} error types for {opName}"
                )

    return testList
3090
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build one test and, if the build succeeds, serialize it.

    Looks ``opName`` up in TOSA_OP_LIST, creates a serializer, generates the
    input tensor values with the op's tensor values generator and then calls
    the op's build function.  Data generation and compliance meta data are
    added to the desc.json meta data when present.

    Returns:
        True if the test was serialized; False if the build function gave
        no result (logged as an invalid ERROR_IF test).

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception(f"Cannot find op with name {opName}")

    logger.info(f"Creating {testStr}")

    # Each test gets its own serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        # CONCAT ops take a variable number of inputs: one dtype per shape
        repeat_count = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat_count

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), f"shapeList length {len(shapeList)} must match number of operands {num_operands}"
        assert (
            len(dtypeList) == num_operands
        ), f"dtypeList length {len(dtypeList)} must match number of operands {num_operands}"

    qgen = op.get("qgen")

    # Random number generator used to build the operands and the test
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    qinfo = None
    if qgen is not None:
        qinfo = qgen(
            build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name
        )

    # Extra meta data for the desc.json
    tensMeta = {}

    # Only the dictionary based tvg/build_fcn interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    tags = argsDict.get("tags", None)

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The build function rejected the test
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return False

    if isinstance(result, TosaTestGen.BuildInfo):
        # Attach compliance meta data, if the builder provided any
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta, tags)
    return True
Matthew Haddon1c00b712021-10-01 15:51:03 +01003197
def createDynamicOpLists(self):
    """Replace every "<name>_TEMPLATE" op with one op per kernel size.

    Kernel sizes come from the level 8k maximum kernel when
    ``self.args.level8k`` is set.  Otherwise they come from
    ``self.args.conv_kernels`` entries of the right rank, falling back to the
    template's own "filter" list.  Each generated op is a copy of the template
    named "<name>_<k0>x<k1>..." with "filter", "template" and "real_name"
    set.  The template entry is then removed.
    """
    op_list = self.TOSA_OP_LIST

    # Collect the templates first as op_list is modified below
    template_names = [
        name for name, entry in op_list.items() if entry.get("template")
    ]

    big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
    suffix = "_TEMPLATE"

    for template_name in template_names:
        assert template_name.endswith(suffix), "Found incorrect template"
        base_name = template_name[: len(template_name) - len(suffix)]
        template = op_list[template_name]
        kernel_rank = 3 if base_name == "conv3d" else 2

        # Choose kernels to build tests for from the template or args
        if self.args.level8k:
            if kernel_rank == 3:
                kernels = [[1, big_k, 1], [2, 2, big_k]]
            else:
                kernels = [[1, big_k], [big_k, 2]]
        else:
            kernels = []
            if len(self.args.conv_kernels) > 0:
                kernels = [
                    kernel
                    for kernel in self.args.conv_kernels
                    if len(kernel) == kernel_rank
                ]
                if not kernels:
                    logger.debug(
                        f"{base_name} op using defaults as no rank {kernel_rank} kernels found in {self.args.conv_kernels}"
                    )
            if not kernels:
                # Fall back to the kernels defined by the template
                kernels = template["filter"]

        for kernel in kernels:
            kernel_str = "x".join(str(dim) for dim in kernel)
            op_list[f"{base_name}_{kernel_str}"] = {
                **template,
                "filter": kernel,
                "template": False,
                "real_name": base_name,
            }

        # The template is no longer needed once its ops exist
        del op_list[template_name]
Eric Kunzee5e26762020-10-13 16:11:07 -07003247
def initOpListDefaults(self):
    """Check every TOSA_OP_LIST entry and fill in optional defaults.

    Every entry must have a two item "operands" tuple, a four item
    "build_fcn" tuple, a "types" field and an "op" field.  A missing "rank"
    is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the op with a missing or malformed required field.
    """
    for op_name, entry in self.TOSA_OP_LIST.items():
        # Required fields - the unpacking also checks the tuple lengths
        try:
            _num_placeholders, _num_consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tensor_gen, _values_gen, _args_gen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(
                f"Op {op_name} is missing a valid type list in TOSA_OP_LIST"
            )

        if "op" not in entry:
            raise Exception(f"Op {op_name} is missing the Op field in TOSA_OP_LIST")

        # Optional fields
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003289
# Tensor operator list
# 'op': TOSA Op enum value for the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#    if not specified, defaults to DEFAULT_RANK_RANGE (0, gtu.MAX_TENSOR_RANK)
# 'build_fcn': tuple of (build function, TensorGen function,
#    TensorValuesGen function, ArgGen function)
# 'types': list of datatypes to be tested
# Other optional keys used by entries include 'qgen', 'invalid_test_validators',
#    'error_if_validators', 'data_gen', 'compliance', 'filter',
#    'broadcastable_bias' and 'template'
# --- Datatype groups used by the "types" field of TOSA_OP_LIST entries ---

# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (INT4 deliberately excluded)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer types followed by FP16, BF16, FP32 (INT4 excluded)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types followed by INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# FP16, BF16, FP32, then the integer types, then BOOL
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integers plus the floating-point types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

# Each row is [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Inclusive (min, max) rank applied by initOpListDefaults when an entry
# has no "rank" key.
DEFAULT_RANK_RANGE = (0, gtu.MAX_TENSOR_RANK)

# Kernel shapes referenced by the "filter" field of the conv templates.
KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

# Per-dtype tuples of gtu.DataGenType modes, referenced by the "data_gen"
# field of TOSA_OP_LIST entries.
PSEUDO_RANDOM_DATAGEN = dict.fromkeys(
    (DType.FP16, DType.FP32), (gtu.DataGenType.PSEUDO_RANDOM,)
)
DOT_PRODUCT_DATAGEN = dict.fromkeys(
    (DType.FP16, DType.FP32), (gtu.DataGenType.DOT_PRODUCT,)
)
EW_UNARY_DATAGEN = {
    DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE),
}
PR_FS_DATAGEN = dict.fromkeys(
    (DType.FP16, DType.FP32),
    (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
)
3368
Eric Kunzee5e26762020-10-13 16:11:07 -07003369 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003371 "argmax": {
3372 "op": Op.ARGMAX,
3373 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00003374 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003375 "build_fcn": (
3376 build_argmax,
3377 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003378 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003379 TosaArgGen.agAxis,
3380 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003381 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003382 "error_if_validators": (
3383 TosaErrorValidator.evAxisSmallerZero,
3384 TosaErrorValidator.evAxisLargerRank,
3385 TosaErrorValidator.evArgmaxOutputRankMismatch,
3386 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3387 TosaErrorValidator.evWrongRank,
3388 TosaErrorValidator.evWrongInputType,
3389 TosaErrorValidator.evWrongOutputType,
3390 TosaErrorValidator.evWrongInputList,
3391 TosaErrorValidator.evWrongOutputList,
3392 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003393 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003394 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003395 "avg_pool2d": {
3396 "op": Op.AVG_POOL2D,
3397 "operands": (1, 0),
3398 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 "build_fcn": (
3400 build_pool2d,
3401 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003402 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003403 TosaArgGen.agPooling,
3404 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003405 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003406 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003407 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003408 "error_if_validators": (
3409 TosaErrorValidator.evKernelSmallerOne,
3410 TosaErrorValidator.evStrideSmallerOne,
3411 TosaErrorValidator.evPadSmallerZero,
3412 TosaErrorValidator.evWrongRank,
3413 TosaErrorValidator.evWrongInputType,
3414 TosaErrorValidator.evWrongOutputType,
3415 TosaErrorValidator.evWrongInputList,
3416 TosaErrorValidator.evWrongOutputList,
3417 TosaErrorValidator.evInputZeroPointNotZero,
3418 TosaErrorValidator.evOutputZeroPointNotZero,
3419 TosaErrorValidator.evPadLargerEqualKernel,
3420 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003421 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003422 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003423 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003424 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003425 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003426 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003427 "conv2d_TEMPLATE": {
3428 "op": Op.CONV2D,
3429 "operands": (1, 2),
3430 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003431 "build_fcn": (
3432 build_conv2d,
3433 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003434 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435 TosaArgGen.agConv,
3436 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003437 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003438 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003439 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3440 "error_if_validators": (
3441 TosaErrorValidator.evWrongInputType,
3442 TosaErrorValidator.evWrongOutputType,
3443 TosaErrorValidator.evWrongInputList,
3444 TosaErrorValidator.evWrongOutputList,
3445 TosaErrorValidator.evInputZeroPointNotZero,
3446 TosaErrorValidator.evWeightZeroPointNotZero,
3447 TosaErrorValidator.evPadSmallerZero,
3448 TosaErrorValidator.evStrideSmallerOne,
3449 TosaErrorValidator.evDilationSmallerOne,
3450 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003451 TosaErrorValidator.evConvOutputShapeMismatch,
3452 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003453 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003454 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003455 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003456 "broadcastable_bias": True,
3457 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003458 "template": True,
3459 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003460 # Templated operator. Filled in by createDynamicOpLists
3461 "conv3d_TEMPLATE": {
3462 "op": Op.CONV3D,
3463 "operands": (1, 2),
3464 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003465 "build_fcn": (
3466 build_conv3d,
3467 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003468 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003469 TosaArgGen.agConv,
3470 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003471 "qgen": TosaQuantGen.qgConv,
3472 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003473 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3474 "error_if_validators": (
3475 TosaErrorValidator.evWrongInputType,
3476 TosaErrorValidator.evWrongOutputType,
3477 TosaErrorValidator.evWrongInputList,
3478 TosaErrorValidator.evWrongOutputList,
3479 TosaErrorValidator.evInputZeroPointNotZero,
3480 TosaErrorValidator.evWeightZeroPointNotZero,
3481 TosaErrorValidator.evPadSmallerZero,
3482 TosaErrorValidator.evStrideSmallerOne,
3483 TosaErrorValidator.evDilationSmallerOne,
3484 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003485 TosaErrorValidator.evConvOutputShapeMismatch,
3486 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003487 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003488 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003489 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003490 "filter": KERNELS_3D,
Kevin Cheng1533b852021-09-01 12:51:58 -07003491 "template": True,
3492 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003493 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003494 "depthwise_conv2d_TEMPLATE": {
3495 "op": Op.DEPTHWISE_CONV2D,
3496 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003497 "rank": (4, 4),
3498 "build_fcn": (
3499 build_depthwise_conv2d,
3500 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003501 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003502 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003503 ),
3504 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003505 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003506 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3507 "error_if_validators": (
3508 TosaErrorValidator.evWrongInputType,
3509 TosaErrorValidator.evWrongOutputType,
3510 TosaErrorValidator.evWrongInputList,
3511 TosaErrorValidator.evWrongOutputList,
3512 TosaErrorValidator.evInputZeroPointNotZero,
3513 TosaErrorValidator.evWeightZeroPointNotZero,
3514 TosaErrorValidator.evPadSmallerZero,
3515 TosaErrorValidator.evStrideSmallerOne,
3516 TosaErrorValidator.evDilationSmallerOne,
3517 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003518 TosaErrorValidator.evConvOutputShapeMismatch,
3519 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003520 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003521 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003522 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003523 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003524 "template": True,
3525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "fully_connected": {
3527 "op": Op.FULLY_CONNECTED,
3528 "operands": (1, 2),
3529 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 "build_fcn": (
3531 build_fully_connected,
3532 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003533 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003534 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003537 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003538 "error_if_validators": (
3539 TosaErrorValidator.evInputZeroPointNotZero,
3540 TosaErrorValidator.evWeightZeroPointNotZero,
3541 TosaErrorValidator.evWrongRank,
3542 TosaErrorValidator.evWrongInputType,
3543 TosaErrorValidator.evWrongOutputType,
3544 TosaErrorValidator.evWrongInputList,
3545 TosaErrorValidator.evWrongOutputList,
3546 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003547 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "matmul": {
3550 "op": Op.MATMUL,
3551 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003552 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 "build_fcn": (
3554 build_matmul,
3555 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003556 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003557 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003558 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003559 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003560 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003561 "error_if_validators": (
3562 TosaErrorValidator.evInputZeroPointNotZero,
3563 TosaErrorValidator.evWrongRank,
3564 TosaErrorValidator.evWrongInputType,
3565 TosaErrorValidator.evWrongOutputType,
3566 TosaErrorValidator.evWrongInputList,
3567 TosaErrorValidator.evWrongOutputList,
3568 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003569 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003570 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003571 "max_pool2d": {
3572 "op": Op.MAX_POOL2D,
3573 "operands": (1, 0),
3574 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003575 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003576 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003577 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003578 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003579 TosaArgGen.agPooling,
3580 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003581 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003582 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003583 "error_if_validators": (
3584 TosaErrorValidator.evKernelSmallerOne,
3585 TosaErrorValidator.evStrideSmallerOne,
3586 TosaErrorValidator.evPadSmallerZero,
3587 TosaErrorValidator.evWrongRank,
3588 TosaErrorValidator.evWrongInputType,
3589 TosaErrorValidator.evWrongOutputType,
3590 TosaErrorValidator.evWrongInputList,
3591 TosaErrorValidator.evWrongOutputList,
3592 TosaErrorValidator.evPadLargerEqualKernel,
3593 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003594 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003595 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003596 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003597 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003598 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003599 "transpose_conv2d_TEMPLATE": {
3600 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003601 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003602 "rank": (4, 4),
3603 "build_fcn": (
3604 build_transpose_conv2d,
3605 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003606 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003607 TosaArgGen.agTransposeConv2D,
3608 ),
3609 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003610 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003611 "invalid_test_validators": (
3612 TosaInvalidValidator.ivHeightWidthInvalid,
3613 TosaInvalidValidator.ivNonPositiveOutputShape,
3614 ),
3615 "error_if_validators": (
3616 TosaErrorValidator.evWrongInputType,
3617 TosaErrorValidator.evWrongOutputType,
3618 TosaErrorValidator.evWrongInputList,
3619 TosaErrorValidator.evWrongOutputList,
3620 TosaErrorValidator.evInputZeroPointNotZero,
3621 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003622 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003623 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003624 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003625 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003626 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003627 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003628 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003629 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003630 "template": True,
3631 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003632 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003633 "clamp": {
3634 "op": Op.CLAMP,
3635 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003636 "build_fcn": (
3637 build_clamp,
3638 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003639 TosaTensorValuesGen.tvgLazyGenDefault,
3640 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003641 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003642 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003643 "error_if_validators": (
3644 TosaErrorValidator.evMaxSmallerMin,
3645 TosaErrorValidator.evWrongInputType,
3646 TosaErrorValidator.evWrongOutputType,
3647 TosaErrorValidator.evWrongInputList,
3648 TosaErrorValidator.evWrongOutputList,
3649 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003650 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003651 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003652 "sigmoid": {
3653 "op": Op.SIGMOID,
3654 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003656 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003658 TosaTensorValuesGen.tvgLazyGenDefault,
3659 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003661 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 "error_if_validators": (
3663 TosaErrorValidator.evWrongInputType,
3664 TosaErrorValidator.evWrongOutputType,
3665 TosaErrorValidator.evWrongInputList,
3666 TosaErrorValidator.evWrongOutputList,
3667 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003668 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003669 },
3670 "tanh": {
3671 "op": Op.TANH,
3672 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003673 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003674 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003675 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003676 TosaTensorValuesGen.tvgLazyGenDefault,
3677 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003678 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003679 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003680 "error_if_validators": (
3681 TosaErrorValidator.evWrongInputType,
3682 TosaErrorValidator.evWrongOutputType,
3683 TosaErrorValidator.evWrongInputList,
3684 TosaErrorValidator.evWrongOutputList,
3685 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003686 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003687 "compliance": {
3688 "abs_error_lower_bound": 0.5,
3689 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003690 },
Won Jeon78155c62023-06-10 00:20:04 +00003691 "erf": {
3692 "op": Op.ERF,
3693 "operands": (1, 0),
3694 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003695 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003696 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003697 TosaTensorValuesGen.tvgLazyGenDefault,
3698 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003699 ),
3700 "types": TYPE_FP,
3701 "error_if_validators": (
3702 TosaErrorValidator.evWrongInputType,
3703 TosaErrorValidator.evWrongOutputType,
3704 TosaErrorValidator.evWrongInputList,
3705 TosaErrorValidator.evWrongOutputList,
3706 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003707 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003708 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003709 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 # Elementwise Binary Operators
3711 "add": {
3712 "op": Op.ADD,
3713 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003714 "build_fcn": (
3715 build_binary_broadcast,
3716 TosaTensorGen.tgBroadcastFuzz,
3717 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003718 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003719 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003721 "error_if_validators": (
3722 TosaErrorValidator.evRankMismatch,
3723 TosaErrorValidator.evWrongInputType,
3724 TosaErrorValidator.evWrongOutputType,
3725 TosaErrorValidator.evWrongInputList,
3726 TosaErrorValidator.evWrongOutputList,
3727 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003728 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003729 ),
evacha014a205112024-03-08 16:39:24 +00003730 "data_gen": PR_FS_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003731 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003732 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "arithmetic_right_shift": {
3734 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3735 "operands": (2, 0),
3736 "build_fcn": (
3737 build_arithmetic_right_shift,
3738 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003739 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003740 TosaArgGen.agArithmeticRightShift,
3741 ),
3742 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003743 "error_if_validators": (
3744 TosaErrorValidator.evRankMismatch,
3745 TosaErrorValidator.evWrongInputType,
3746 TosaErrorValidator.evWrongOutputType,
3747 TosaErrorValidator.evWrongInputList,
3748 TosaErrorValidator.evWrongOutputList,
3749 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003750 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003751 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003752 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "bitwise_and": {
3754 "op": Op.BITWISE_AND,
3755 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003756 "build_fcn": (
3757 build_binary_broadcast,
3758 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003759 TosaTensorValuesGen.tvgLazyGenDefault,
3760 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003761 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003762 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003763 "error_if_validators": (
3764 TosaErrorValidator.evRankMismatch,
3765 TosaErrorValidator.evWrongInputType,
3766 TosaErrorValidator.evWrongOutputType,
3767 TosaErrorValidator.evWrongInputList,
3768 TosaErrorValidator.evWrongOutputList,
3769 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003770 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003772 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 "bitwise_or": {
3774 "op": Op.BITWISE_OR,
3775 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003776 "build_fcn": (
3777 build_binary_broadcast,
3778 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003779 TosaTensorValuesGen.tvgLazyGenDefault,
3780 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003781 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003782 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003783 "error_if_validators": (
3784 TosaErrorValidator.evRankMismatch,
3785 TosaErrorValidator.evWrongInputType,
3786 TosaErrorValidator.evWrongOutputType,
3787 TosaErrorValidator.evWrongInputList,
3788 TosaErrorValidator.evWrongOutputList,
3789 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003790 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003792 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003793 "bitwise_xor": {
3794 "op": Op.BITWISE_XOR,
3795 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003796 "build_fcn": (
3797 build_binary_broadcast,
3798 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003799 TosaTensorValuesGen.tvgLazyGenDefault,
3800 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003802 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003803 "error_if_validators": (
3804 TosaErrorValidator.evRankMismatch,
3805 TosaErrorValidator.evWrongInputType,
3806 TosaErrorValidator.evWrongOutputType,
3807 TosaErrorValidator.evWrongInputList,
3808 TosaErrorValidator.evWrongOutputList,
3809 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003810 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003811 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003812 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003813 "intdiv": {
3814 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003815 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003816 "build_fcn": (
3817 build_binary_broadcast,
3818 TosaTensorGen.tgBroadcastFuzz,
3819 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003820 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003821 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003822 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003823 "error_if_validators": (
3824 TosaErrorValidator.evRankMismatch,
3825 TosaErrorValidator.evWrongInputType,
3826 TosaErrorValidator.evWrongOutputType,
3827 TosaErrorValidator.evWrongInputList,
3828 TosaErrorValidator.evWrongOutputList,
3829 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003830 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003832 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 "logical_and": {
3834 "op": Op.LOGICAL_AND,
3835 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003836 "build_fcn": (
3837 build_binary_broadcast,
3838 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003839 TosaTensorValuesGen.tvgLazyGenDefault,
3840 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003842 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003843 "error_if_validators": (
3844 TosaErrorValidator.evRankMismatch,
3845 TosaErrorValidator.evWrongInputType,
3846 TosaErrorValidator.evWrongOutputType,
3847 TosaErrorValidator.evWrongInputList,
3848 TosaErrorValidator.evWrongOutputList,
3849 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003850 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003852 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "logical_left_shift": {
3854 "op": Op.LOGICAL_LEFT_SHIFT,
3855 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003856 "build_fcn": (
3857 build_binary_broadcast,
3858 TosaTensorGen.tgBroadcastFuzz,
3859 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003860 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003861 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003862 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003863 "error_if_validators": (
3864 TosaErrorValidator.evRankMismatch,
3865 TosaErrorValidator.evWrongInputType,
3866 TosaErrorValidator.evWrongOutputType,
3867 TosaErrorValidator.evWrongInputList,
3868 TosaErrorValidator.evWrongOutputList,
3869 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003870 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003872 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 "logical_right_shift": {
3874 "op": Op.LOGICAL_RIGHT_SHIFT,
3875 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003876 "build_fcn": (
3877 build_binary_broadcast,
3878 TosaTensorGen.tgBroadcastFuzz,
3879 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003880 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003881 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003882 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003883 "error_if_validators": (
3884 TosaErrorValidator.evRankMismatch,
3885 TosaErrorValidator.evWrongInputType,
3886 TosaErrorValidator.evWrongOutputType,
3887 TosaErrorValidator.evWrongInputList,
3888 TosaErrorValidator.evWrongOutputList,
3889 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003890 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003892 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 "logical_or": {
3894 "op": Op.LOGICAL_OR,
3895 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003896 "build_fcn": (
3897 build_binary_broadcast,
3898 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003899 TosaTensorValuesGen.tvgLazyGenDefault,
3900 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003901 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003902 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003903 "error_if_validators": (
3904 TosaErrorValidator.evRankMismatch,
3905 TosaErrorValidator.evWrongInputType,
3906 TosaErrorValidator.evWrongOutputType,
3907 TosaErrorValidator.evWrongInputList,
3908 TosaErrorValidator.evWrongOutputList,
3909 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003910 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003911 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003912 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003913 "logical_xor": {
3914 "op": Op.LOGICAL_XOR,
3915 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003916 "build_fcn": (
3917 build_binary_broadcast,
3918 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003919 TosaTensorValuesGen.tvgLazyGenDefault,
3920 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003921 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003922 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003923 "error_if_validators": (
3924 TosaErrorValidator.evRankMismatch,
3925 TosaErrorValidator.evWrongInputType,
3926 TosaErrorValidator.evWrongOutputType,
3927 TosaErrorValidator.evWrongInputList,
3928 TosaErrorValidator.evWrongOutputList,
3929 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003930 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003932 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003933 "maximum": {
3934 "op": Op.MAXIMUM,
3935 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003936 "build_fcn": (
3937 build_binary_broadcast,
3938 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003939 TosaTensorValuesGen.tvgLazyGenDefault,
3940 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003941 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003942 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003943 "error_if_validators": (
3944 TosaErrorValidator.evRankMismatch,
3945 TosaErrorValidator.evWrongInputType,
3946 TosaErrorValidator.evWrongOutputType,
3947 TosaErrorValidator.evWrongInputList,
3948 TosaErrorValidator.evWrongOutputList,
3949 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003950 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003951 ),
evacha014a205112024-03-08 16:39:24 +00003952 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003954 "minimum": {
3955 "op": Op.MINIMUM,
3956 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003957 "build_fcn": (
3958 build_binary_broadcast,
3959 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003960 TosaTensorValuesGen.tvgLazyGenDefault,
3961 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 "error_if_validators": (
3965 TosaErrorValidator.evRankMismatch,
3966 TosaErrorValidator.evWrongInputType,
3967 TosaErrorValidator.evWrongOutputType,
3968 TosaErrorValidator.evWrongInputList,
3969 TosaErrorValidator.evWrongOutputList,
3970 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003971 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003972 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003973 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003974 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003975 "mul": {
3976 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003977 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003978 "build_fcn": (
3979 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003980 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003981 TosaTensorValuesGen.tvgMul,
3982 TosaArgGen.agMul,
3983 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003984 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003985 "error_if_validators": (
3986 TosaErrorValidator.evWrongInputType,
3987 TosaErrorValidator.evWrongOutputType,
3988 TosaErrorValidator.evWrongInputList,
3989 TosaErrorValidator.evWrongOutputList,
3990 TosaErrorValidator.evRankMismatch,
3991 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003992 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003993 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003994 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003995 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003996 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003997 "pow": {
3998 "op": Op.POW,
3999 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004000 "build_fcn": (
4001 build_binary_broadcast,
4002 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00004003 TosaTensorValuesGen.tvgPow,
4004 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004005 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004006 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004007 "error_if_validators": (
4008 TosaErrorValidator.evRankMismatch,
4009 TosaErrorValidator.evWrongInputType,
4010 TosaErrorValidator.evWrongOutputType,
4011 TosaErrorValidator.evWrongInputList,
4012 TosaErrorValidator.evWrongOutputList,
4013 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004014 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004015 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004016 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004018 "sub": {
4019 "op": Op.SUB,
4020 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004021 "build_fcn": (
4022 build_binary_broadcast,
4023 TosaTensorGen.tgBroadcastFuzz,
4024 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004025 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004026 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004027 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004028 "error_if_validators": (
4029 TosaErrorValidator.evRankMismatch,
4030 TosaErrorValidator.evWrongInputType,
4031 TosaErrorValidator.evWrongOutputType,
4032 TosaErrorValidator.evWrongInputList,
4033 TosaErrorValidator.evWrongOutputList,
4034 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004035 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004036 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004037 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004038 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004039 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004040 "table": {
4041 "op": Op.TABLE,
4042 # Use the automatic generation functions to create the input array
4043 # but create the table tensor in the build function, as it may be
4044 # a different type from the input
4045 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004046 "build_fcn": (
4047 build_table,
4048 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004049 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004050 TosaArgGen.agTable,
4051 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004052 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004053 "error_if_validators": (
4054 TosaErrorValidator.evWrongInputType,
4055 TosaErrorValidator.evWrongOutputType,
4056 TosaErrorValidator.evWrongInputList,
4057 TosaErrorValidator.evWrongOutputList,
4058 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 # Elementwise Unary operators
4061 "abs": {
4062 "op": Op.ABS,
4063 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004064 "build_fcn": (
4065 build_unary,
4066 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004067 TosaTensorValuesGen.tvgLazyGenDefault,
4068 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004069 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004070 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004071 "error_if_validators": (
4072 TosaErrorValidator.evWrongInputType,
4073 TosaErrorValidator.evWrongOutputType,
4074 TosaErrorValidator.evWrongInputList,
4075 TosaErrorValidator.evWrongOutputList,
4076 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004077 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004079 "bitwise_not": {
4080 "op": Op.BITWISE_NOT,
4081 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 "build_fcn": (
4083 build_unary,
4084 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004085 TosaTensorValuesGen.tvgLazyGenDefault,
4086 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004088 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004089 "error_if_validators": (
4090 TosaErrorValidator.evWrongInputType,
4091 TosaErrorValidator.evWrongOutputType,
4092 TosaErrorValidator.evWrongInputList,
4093 TosaErrorValidator.evWrongOutputList,
4094 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 "ceil": {
4097 "op": Op.CEIL,
4098 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004099 "build_fcn": (
4100 build_unary,
4101 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004102 TosaTensorValuesGen.tvgLazyGenDefault,
4103 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 "error_if_validators": (
4107 TosaErrorValidator.evWrongInputType,
4108 TosaErrorValidator.evWrongOutputType,
4109 TosaErrorValidator.evWrongInputList,
4110 TosaErrorValidator.evWrongOutputList,
4111 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004112 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004113 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004114 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004115 "clz": {
4116 "op": Op.CLZ,
4117 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004118 "build_fcn": (
4119 build_unary,
4120 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004121 TosaTensorValuesGen.tvgLazyGenDefault,
4122 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004123 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004124 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004125 "error_if_validators": (
4126 TosaErrorValidator.evWrongInputType,
4127 TosaErrorValidator.evWrongOutputType,
4128 TosaErrorValidator.evWrongInputList,
4129 TosaErrorValidator.evWrongOutputList,
4130 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004131 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004132 "cos": {
4133 "op": Op.COS,
4134 "operands": (1, 0),
4135 "build_fcn": (
4136 build_unary,
4137 TosaTensorGen.tgBasic,
4138 TosaTensorValuesGen.tvgLazyGenDefault,
4139 TosaArgGen.agNone,
4140 ),
4141 "types": TYPE_FP,
4142 "error_if_validators": (
4143 TosaErrorValidator.evWrongInputType,
4144 TosaErrorValidator.evWrongOutputType,
4145 TosaErrorValidator.evWrongInputList,
4146 TosaErrorValidator.evWrongOutputList,
4147 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004148 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson1eb14552024-04-11 16:21:54 +01004149 "compliance": {
4150 "abs_error_normal_divisor": 2,
4151 "abs_error_bound_addition": 1,
4152 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004153 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004154 "exp": {
4155 "op": Op.EXP,
4156 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004157 "build_fcn": (
4158 build_unary,
4159 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004160 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004161 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004162 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004163 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004164 "error_if_validators": (
4165 TosaErrorValidator.evWrongInputType,
4166 TosaErrorValidator.evWrongOutputType,
4167 TosaErrorValidator.evWrongInputList,
4168 TosaErrorValidator.evWrongOutputList,
4169 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004170 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004171 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004172 "floor": {
4173 "op": Op.FLOOR,
4174 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004175 "build_fcn": (
4176 build_unary,
4177 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004178 TosaTensorValuesGen.tvgLazyGenDefault,
4179 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004180 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004181 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004182 "error_if_validators": (
4183 TosaErrorValidator.evWrongInputType,
4184 TosaErrorValidator.evWrongOutputType,
4185 TosaErrorValidator.evWrongInputList,
4186 TosaErrorValidator.evWrongOutputList,
4187 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004188 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004189 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004190 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004191 "log": {
4192 "op": Op.LOG,
4193 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004194 "build_fcn": (
4195 build_unary,
4196 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004197 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004198 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004199 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004201 "error_if_validators": (
4202 TosaErrorValidator.evWrongInputType,
4203 TosaErrorValidator.evWrongOutputType,
4204 TosaErrorValidator.evWrongInputList,
4205 TosaErrorValidator.evWrongOutputList,
4206 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004207 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004208 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004209 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004210 "logical_not": {
4211 "op": Op.LOGICAL_NOT,
4212 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004213 "build_fcn": (
4214 build_unary,
4215 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004216 TosaTensorValuesGen.tvgLazyGenDefault,
4217 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004218 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004219 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004220 "error_if_validators": (
4221 TosaErrorValidator.evWrongInputType,
4222 TosaErrorValidator.evWrongOutputType,
4223 TosaErrorValidator.evWrongInputList,
4224 TosaErrorValidator.evWrongOutputList,
4225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004226 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004227 "negate": {
4228 "op": Op.NEGATE,
4229 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004230 "build_fcn": (
4231 build_unary,
4232 TosaTensorGen.tgBasic,
4233 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004234 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004235 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004236 "qgen": TosaQuantGen.qgUnary,
4237 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004238 "error_if_validators": (
4239 TosaErrorValidator.evInputZeroPointNotZero,
4240 TosaErrorValidator.evOutputZeroPointNotZero,
4241 TosaErrorValidator.evWrongInputType,
4242 TosaErrorValidator.evWrongOutputType,
4243 TosaErrorValidator.evWrongInputList,
4244 TosaErrorValidator.evWrongOutputList,
4245 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004246 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004247 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004248 "reciprocal": {
4249 "op": Op.RECIPROCAL,
4250 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004251 "build_fcn": (
4252 build_unary,
4253 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004254 TosaTensorValuesGen.tvgLazyGenDefault,
4255 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004256 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004257 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004258 "error_if_validators": (
4259 TosaErrorValidator.evWrongInputType,
4260 TosaErrorValidator.evWrongOutputType,
4261 TosaErrorValidator.evWrongInputList,
4262 TosaErrorValidator.evWrongOutputList,
4263 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004264 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004265 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004266 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004267 "rsqrt": {
4268 "op": Op.RSQRT,
4269 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004270 "build_fcn": (
4271 build_unary,
4272 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004273 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004274 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004275 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004276 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004277 "error_if_validators": (
4278 TosaErrorValidator.evWrongInputType,
4279 TosaErrorValidator.evWrongOutputType,
4280 TosaErrorValidator.evWrongInputList,
4281 TosaErrorValidator.evWrongOutputList,
4282 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004283 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004284 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004285 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004286 "sin": {
4287 "op": Op.SIN,
4288 "operands": (1, 0),
4289 "build_fcn": (
4290 build_unary,
4291 TosaTensorGen.tgBasic,
4292 TosaTensorValuesGen.tvgLazyGenDefault,
4293 TosaArgGen.agNone,
4294 ),
4295 "types": TYPE_FP,
4296 "error_if_validators": (
4297 TosaErrorValidator.evWrongInputType,
4298 TosaErrorValidator.evWrongOutputType,
4299 TosaErrorValidator.evWrongInputList,
4300 TosaErrorValidator.evWrongOutputList,
4301 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004302 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004303 "compliance": {"abs_error_normal_divisor": 2},
4304 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004305 # Elementwise Ternary operators
4306 "select": {
4307 "op": Op.SELECT,
4308 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004309 "build_fcn": (
4310 build_select,
4311 TosaTensorGen.tgBroadcastFuzz,
4312 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004313 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004314 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004315 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004316 "error_if_validators": (
4317 TosaErrorValidator.evRankMismatch,
4318 TosaErrorValidator.evWrongInputType,
4319 TosaErrorValidator.evWrongOutputType,
4320 TosaErrorValidator.evWrongInputList,
4321 TosaErrorValidator.evWrongOutputList,
4322 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004323 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004324 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004325 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004326 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004327 # Comparison operators
4328 "equal": {
4329 "op": Op.EQUAL,
4330 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004331 "build_fcn": (
4332 build_comparison,
4333 TosaTensorGen.tgBroadcastFuzz,
4334 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004335 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004336 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004337 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004338 "error_if_validators": (
4339 TosaErrorValidator.evRankMismatch,
4340 TosaErrorValidator.evWrongInputType,
4341 TosaErrorValidator.evWrongOutputType,
4342 TosaErrorValidator.evWrongInputList,
4343 TosaErrorValidator.evWrongOutputList,
4344 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004345 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004346 ),
evacha014a205112024-03-08 16:39:24 +00004347 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004348 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004349 "greater_equal": {
4350 "op": Op.GREATER_EQUAL,
4351 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004352 "build_fcn": (
4353 build_comparison,
4354 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004355 TosaTensorValuesGen.tvgLazyGenDefault,
4356 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004357 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004358 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004359 "error_if_validators": (
4360 TosaErrorValidator.evRankMismatch,
4361 TosaErrorValidator.evWrongInputType,
4362 TosaErrorValidator.evWrongOutputType,
4363 TosaErrorValidator.evWrongInputList,
4364 TosaErrorValidator.evWrongOutputList,
4365 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004366 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004367 ),
evacha014a205112024-03-08 16:39:24 +00004368 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004370 "greater": {
4371 "op": Op.GREATER,
4372 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004373 "build_fcn": (
4374 build_comparison,
4375 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004376 TosaTensorValuesGen.tvgLazyGenDefault,
4377 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004378 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004379 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004380 "error_if_validators": (
4381 TosaErrorValidator.evRankMismatch,
4382 TosaErrorValidator.evWrongInputType,
4383 TosaErrorValidator.evWrongOutputType,
4384 TosaErrorValidator.evWrongInputList,
4385 TosaErrorValidator.evWrongOutputList,
4386 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004387 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004388 ),
evacha014a205112024-03-08 16:39:24 +00004389 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 # Reduction operators
4392 "reduce_all": {
4393 "op": Op.REDUCE_ALL,
4394 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004395 "build_fcn": (
4396 build_reduce,
4397 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004398 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004399 TosaArgGen.agAxis,
4400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004401 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004402 "error_if_validators": (
4403 TosaErrorValidator.evAxisLargerRank,
4404 TosaErrorValidator.evAxisSmallerZero,
4405 TosaErrorValidator.evShapeOfAxisNotOne,
4406 TosaErrorValidator.evWrongInputType,
4407 TosaErrorValidator.evWrongOutputType,
4408 TosaErrorValidator.evWrongRank,
4409 TosaErrorValidator.evWrongInputList,
4410 TosaErrorValidator.evWrongOutputList,
4411 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004413 "reduce_any": {
4414 "op": Op.REDUCE_ANY,
4415 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004416 "build_fcn": (
4417 build_reduce,
4418 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004419 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004420 TosaArgGen.agAxis,
4421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004422 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004423 "error_if_validators": (
4424 TosaErrorValidator.evAxisLargerRank,
4425 TosaErrorValidator.evAxisSmallerZero,
4426 TosaErrorValidator.evShapeOfAxisNotOne,
4427 TosaErrorValidator.evWrongInputType,
4428 TosaErrorValidator.evWrongOutputType,
4429 TosaErrorValidator.evWrongRank,
4430 TosaErrorValidator.evWrongInputList,
4431 TosaErrorValidator.evWrongOutputList,
4432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004434 "reduce_max": {
4435 "op": Op.REDUCE_MAX,
4436 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 "build_fcn": (
4438 build_reduce,
4439 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004440 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004441 TosaArgGen.agAxis,
4442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004443 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004444 "error_if_validators": (
4445 TosaErrorValidator.evAxisLargerRank,
4446 TosaErrorValidator.evAxisSmallerZero,
4447 TosaErrorValidator.evShapeOfAxisNotOne,
4448 TosaErrorValidator.evWrongInputType,
4449 TosaErrorValidator.evWrongOutputType,
4450 TosaErrorValidator.evWrongRank,
4451 TosaErrorValidator.evWrongInputList,
4452 TosaErrorValidator.evWrongOutputList,
4453 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004454 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004455 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004456 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004457 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004458 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004459 "build_fcn": (
4460 build_reduce,
4461 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004462 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004463 TosaArgGen.agAxis,
4464 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004465 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004466 "error_if_validators": (
4467 TosaErrorValidator.evAxisLargerRank,
4468 TosaErrorValidator.evAxisSmallerZero,
4469 TosaErrorValidator.evShapeOfAxisNotOne,
4470 TosaErrorValidator.evWrongInputType,
4471 TosaErrorValidator.evWrongOutputType,
4472 TosaErrorValidator.evWrongRank,
4473 TosaErrorValidator.evWrongInputList,
4474 TosaErrorValidator.evWrongOutputList,
4475 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004476 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004477 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004478 "reduce_product": {
4479 "op": Op.REDUCE_PRODUCT,
4480 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004481 "build_fcn": (
4482 build_reduce,
4483 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004484 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004485 TosaArgGen.agAxis,
4486 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004487 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004488 "error_if_validators": (
4489 TosaErrorValidator.evAxisLargerRank,
4490 TosaErrorValidator.evAxisSmallerZero,
4491 TosaErrorValidator.evShapeOfAxisNotOne,
4492 TosaErrorValidator.evWrongInputType,
4493 TosaErrorValidator.evWrongOutputType,
4494 TosaErrorValidator.evWrongRank,
4495 TosaErrorValidator.evWrongInputList,
4496 TosaErrorValidator.evWrongOutputList,
4497 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004498 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004499 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004500 "reduce_sum": {
4501 "op": Op.REDUCE_SUM,
4502 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004503 "build_fcn": (
4504 build_reduce,
4505 TosaTensorGen.tgBasic,
4506 TosaTensorValuesGen.tvgReduceSum,
4507 TosaArgGen.agAxis,
4508 ),
James Ward24dbc422022-10-19 12:20:31 +01004509 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004510 "error_if_validators": (
4511 TosaErrorValidator.evAxisLargerRank,
4512 TosaErrorValidator.evAxisSmallerZero,
4513 TosaErrorValidator.evShapeOfAxisNotOne,
4514 TosaErrorValidator.evWrongInputType,
4515 TosaErrorValidator.evWrongOutputType,
4516 TosaErrorValidator.evWrongRank,
4517 TosaErrorValidator.evWrongInputList,
4518 TosaErrorValidator.evWrongOutputList,
4519 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004520 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004521 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004522 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004523 "concat": {
4524 "op": Op.CONCAT,
4525 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004526 "build_fcn": (
4527 build_concat,
4528 TosaTensorGen.tgConcat,
4529 TosaTensorValuesGen.tvgConcat,
4530 TosaArgGen.agAxis,
4531 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004532 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004533 "error_if_validators": (
4534 TosaErrorValidator.evAxisLargerRank,
4535 TosaErrorValidator.evAxisSmallerZero,
4536 TosaErrorValidator.evConcatInputRankMismatch,
4537 TosaErrorValidator.evConcatShapeSumMismatch,
4538 TosaErrorValidator.evConcatInputDimMismatch,
4539 TosaErrorValidator.evWrongInputType,
4540 TosaErrorValidator.evWrongOutputType,
4541 TosaErrorValidator.evWrongOutputList,
4542 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004543 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004544 },
4545 "pad": {
4546 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004547 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004548 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004549 "build_fcn": (
4550 build_pad,
4551 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004552 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004553 TosaArgGen.agPad,
4554 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004555 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004556 "error_if_validators": (
4557 TosaErrorValidator.evWrongInputType,
4558 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004559 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004560 TosaErrorValidator.evWrongOutputType,
4561 TosaErrorValidator.evWrongInputList,
4562 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004563 TosaErrorValidator.evRankMismatch,
4564 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004565 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004566 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004567 },
Won Jeona21b2e82023-08-10 10:33:01 +00004568 "dim": {
4569 "op": Op.DIM,
4570 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004571 "rank": (1, gtu.MAX_TENSOR_RANK),
Won Jeona21b2e82023-08-10 10:33:01 +00004572 "build_fcn": (
4573 build_dim,
4574 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004575 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004576 TosaArgGen.agAxis,
4577 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004578 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004579 "error_if_validators": (
4580 TosaErrorValidator.evAxisLargerRank,
4581 TosaErrorValidator.evAxisSmallerZero,
4582 TosaErrorValidator.evWrongInputType,
4583 TosaErrorValidator.evWrongInputList,
4584 TosaErrorValidator.evWrongOutputList,
4585 TosaErrorValidator.evWrongRank,
4586 ),
4587 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004588 "reshape": {
4589 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004590 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004591 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004592 "build_fcn": (
4593 build_reshape,
4594 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004595 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004596 TosaArgGen.agReshape,
4597 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004598 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004599 "error_if_validators": (
4600 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4601 TosaErrorValidator.evWrongInputType,
4602 TosaErrorValidator.evWrongOutputType,
4603 TosaErrorValidator.evWrongInputList,
4604 TosaErrorValidator.evWrongOutputList,
4605 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004606 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004607 },
4608 "reverse": {
4609 "op": Op.REVERSE,
4610 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004611 "build_fcn": (
4612 build_reverse,
4613 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004614 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004615 TosaArgGen.agAxis,
4616 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004617 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004618 "error_if_validators": (
4619 TosaErrorValidator.evAxisSmallerZero,
4620 TosaErrorValidator.evAxisLargerRank,
4621 TosaErrorValidator.evWrongInputType,
4622 TosaErrorValidator.evWrongOutputType,
4623 TosaErrorValidator.evWrongInputList,
4624 TosaErrorValidator.evWrongOutputList,
4625 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004626 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004627 },
4628 "slice": {
4629 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004630 "operands": (3, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004631 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004632 "build_fcn": (
4633 build_slice,
4634 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004635 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004636 TosaArgGen.agSlice,
4637 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004638 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004639 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004640 # TODO Turn off these error categories for now as the reference
4641 # model cannot allocate memory space for empty tensor. We probably
4642 # can report an accurate error message at the right place during
4643 # execution.
4644 # TosaErrorValidator.evStartSmallerZero,
4645 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004646 TosaErrorValidator.evStartSizeOutsideBounds,
4647 TosaErrorValidator.evSizeOutputShapeMismatch,
4648 TosaErrorValidator.evInputSizeStartLengthMismatch,
4649 TosaErrorValidator.evWrongRank,
4650 TosaErrorValidator.evWrongInputType,
4651 TosaErrorValidator.evWrongOutputType,
4652 TosaErrorValidator.evWrongInputList,
4653 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004654 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004655 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004656 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004657 },
4658 "tile": {
4659 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004660 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004661 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004662 "build_fcn": (
4663 build_tile,
4664 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004665 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004666 TosaArgGen.agTile,
4667 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004668 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004669 "error_if_validators": (
4670 TosaErrorValidator.evWrongInputType,
4671 TosaErrorValidator.evWrongOutputType,
4672 TosaErrorValidator.evWrongInputList,
4673 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004674 TosaErrorValidator.evRankMismatch,
4675 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004676 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004677 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004678 },
4679 "transpose": {
4680 "op": Op.TRANSPOSE,
4681 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004682 "rank": (1, gtu.MAX_TENSOR_RANK),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004683 "build_fcn": (
4684 build_transpose,
4685 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004686 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004687 TosaArgGen.agTranspose,
4688 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004689 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004690 "error_if_validators": (
4691 TosaErrorValidator.evIndexOutsideBounds,
4692 TosaErrorValidator.evIndexUsedTwice,
4693 TosaErrorValidator.evWrongInputType,
4694 TosaErrorValidator.evWrongOutputType,
4695 TosaErrorValidator.evWrongInputList,
4696 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004697 TosaErrorValidator.evWrongRank,
4698 TosaErrorValidator.evRankMismatch,
4699 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004700 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004701 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004702 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004703 # Data nodes
4704 "const": {
4705 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004706 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004707 "build_fcn": (
4708 build_const,
4709 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004710 TosaTensorValuesGen.tvgLazyGenDefault,
4711 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004712 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004713 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha01ad8e1e22024-03-19 12:42:17 +00004714 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004715 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004716 "identity": {
4717 "op": Op.IDENTITY,
4718 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004719 "build_fcn": (
4720 build_unary,
4721 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004722 TosaTensorValuesGen.tvgLazyGenDefault,
4723 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004724 ),
evacha011adff832024-03-06 17:33:44 +00004725 "types": TYPE_FIB + [DType.INT4, DType.INT48],
evacha01ad8e1e22024-03-19 12:42:17 +00004726 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004727 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004728 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004729 "gather": {
4730 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004731 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004732 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004733 "build_fcn": (
4734 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004735 TosaTensorGen.tgGather,
4736 TosaTensorValuesGen.tvgGather,
4737 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004738 ),
James Ward24dbc422022-10-19 12:20:31 +01004739 "types": (
4740 DType.INT8,
4741 DType.INT16,
4742 DType.INT32,
4743 DType.FP16,
4744 DType.BF16,
4745 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004746 DType.FP8E4M3,
4747 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004748 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004749 "error_if_validators": (
4750 TosaErrorValidator.evWrongInputType,
4751 TosaErrorValidator.evWrongOutputType,
4752 TosaErrorValidator.evWrongInputList,
4753 TosaErrorValidator.evWrongOutputList,
4754 TosaErrorValidator.evWrongRank,
4755 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004756 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004757 },
4758 "scatter": {
4759 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004760 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004761 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004762 "build_fcn": (
4763 build_scatter,
4764 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004765 TosaTensorValuesGen.tvgScatter,
4766 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004767 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004768 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004769 "error_if_validators": (
4770 TosaErrorValidator.evWrongInputType,
4771 TosaErrorValidator.evWrongOutputType,
4772 TosaErrorValidator.evWrongInputList,
4773 TosaErrorValidator.evWrongOutputList,
4774 TosaErrorValidator.evWrongRank,
4775 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004776 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004777 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004778 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004779 "resize": {
4780 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004781 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004782 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004783 "build_fcn": (
4784 build_resize,
4785 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004786 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004787 TosaArgGen.agResize,
4788 ),
James Ward24dbc422022-10-19 12:20:31 +01004789 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004790 "invalid_test_validators": (
4791 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004792 ),
4793 "error_if_validators": (
4794 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004795 TosaErrorValidator.evScaleSmallerEqualZero,
4796 TosaErrorValidator.evScaleNLargerMax,
4797 TosaErrorValidator.evScaleDLargerMax,
4798 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004799 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004800 TosaErrorValidator.evBorderSmallerMin,
4801 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004802 TosaErrorValidator.evWrongInputType,
4803 TosaErrorValidator.evWrongOutputType,
4804 TosaErrorValidator.evWrongRank,
4805 TosaErrorValidator.evWrongInputList,
4806 TosaErrorValidator.evWrongOutputList,
4807 TosaErrorValidator.evBatchMismatch,
4808 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004809 TosaErrorValidator.evResizeOutputShapeMismatch,
4810 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004811 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004812 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004813 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004814 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004815 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004816 "cast": {
4817 "op": Op.CAST,
4818 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004819 "build_fcn": (
4820 build_cast,
4821 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004822 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004823 TosaArgGen.agCast,
4824 ),
James Ward8b390432022-08-12 20:48:56 +01004825 "types": (
4826 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004827 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004828 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004829 DType.INT8,
4830 DType.INT16,
4831 DType.INT32,
4832 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004833 DType.FP8E4M3,
4834 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004835 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004836 "error_if_validators": (
4837 TosaErrorValidator.evWrongInputType,
4838 TosaErrorValidator.evWrongOutputType,
4839 TosaErrorValidator.evWrongInputList,
4840 TosaErrorValidator.evWrongOutputList,
4841 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004842 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson708da822023-11-15 16:25:45 +00004843 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004844 },
4845 "rescale": {
4846 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004847 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004848 "build_fcn": (
4849 build_rescale,
4850 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004851 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004852 TosaArgGen.agRescale,
4853 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004854 "types": [
4855 DType.UINT8,
4856 DType.INT8,
4857 DType.INT16,
4858 DType.INT32,
4859 DType.INT48,
4860 DType.UINT16,
4861 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004862 "error_if_validators": (
4863 TosaErrorValidator.evInputZeroPointNotZero,
4864 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004865 TosaErrorValidator.evU16InputZeroPointNotValid,
4866 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004867 TosaErrorValidator.evScaleTrue,
4868 TosaErrorValidator.evScaleNotTrue,
4869 TosaErrorValidator.evWrongInputType,
4870 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004871 TosaErrorValidator.evWrongInputList,
4872 TosaErrorValidator.evWrongOutputList,
4873 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004874 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004875 # Custom
4876 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004877 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004878 # Two variants of cond_if, one that generates one of two constant tensors (no
4879 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4880 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004881 "cond_if_const": {
4882 "op": Op.COND_IF,
4883 "operands": (0, 2),
4884 "build_fcn": (
4885 build_cond_if_const,
4886 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004887 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004888 TosaArgGen.agCondIf,
4889 ),
4890 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004891 "error_if_validators": (
4892 TosaErrorValidator.evOutputListThenGraphMismatch,
4893 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004894 TosaErrorValidator.evCondIfCondNotMatchingBool,
4895 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004896 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004897 },
4898 "cond_if_binary": {
4899 "op": Op.COND_IF,
4900 "operands": (2, 0),
4901 "build_fcn": (
4902 build_cond_if_binary,
4903 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004904 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004905 TosaArgGen.agCondIf,
4906 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004907 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004908 "error_if_validators": (
4909 TosaErrorValidator.evInputListThenGraphMismatch,
4910 TosaErrorValidator.evInputListElseGraphMismatch,
4911 TosaErrorValidator.evOutputListThenGraphMismatch,
4912 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004913 TosaErrorValidator.evCondIfCondNotMatchingBool,
4914 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004915 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004916 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004917 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004918 "while_loop": {
4919 "op": Op.WHILE_LOOP,
4920 "operands": (0, 1),
4921 "build_fcn": (
4922 build_while_loop,
4923 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004924 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004925 TosaArgGen.agWhileLoop,
4926 ),
4927 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004928 "error_if_validators": (
4929 TosaErrorValidator.evInputListOutputListMismatch,
4930 TosaErrorValidator.evInputListCondGraphMismatch,
4931 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4932 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4933 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004934 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004935 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004936 },
Luke Hutton57287132023-02-06 14:54:18 +00004937 "fft2d": {
4938 "op": Op.FFT2D,
4939 "operands": (2, 0),
4940 "rank": (3, 3),
4941 "build_fcn": (
4942 build_fft2d,
4943 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004944 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004945 TosaArgGen.agFFT2d,
4946 ),
4947 "types": [DType.FP32],
4948 "error_if_validators": (
4949 TosaErrorValidator.evWrongInputType,
4950 TosaErrorValidator.evWrongOutputType,
4951 TosaErrorValidator.evWrongInputList,
4952 TosaErrorValidator.evWrongOutputList,
4953 TosaErrorValidator.evWrongRank,
4954 TosaErrorValidator.evBatchMismatch,
4955 TosaErrorValidator.evKernelNotPowerOfTwo,
4956 TosaErrorValidator.evFFTInputShapeMismatch,
4957 TosaErrorValidator.evFFTOutputShapeMismatch,
4958 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004959 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton57287132023-02-06 14:54:18 +00004960 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004961 "rfft2d": {
4962 "op": Op.RFFT2D,
4963 "operands": (1, 0),
4964 "rank": (3, 3),
4965 "build_fcn": (
4966 build_rfft2d,
4967 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004968 TosaTensorValuesGen.tvgLazyGenDefault,
4969 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004970 ),
4971 "types": [DType.FP32],
4972 "error_if_validators": (
4973 TosaErrorValidator.evWrongInputType,
4974 TosaErrorValidator.evWrongOutputType,
4975 TosaErrorValidator.evWrongInputList,
4976 TosaErrorValidator.evWrongOutputList,
4977 TosaErrorValidator.evWrongRank,
4978 TosaErrorValidator.evBatchMismatch,
4979 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004980 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004981 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004982 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton261b7b62023-01-10 14:50:31 +00004983 },
Won Jeon74342e52024-01-09 00:34:40 +00004984 # Shape
4985 "add_shape": {
4986 "op": Op.ADD_SHAPE,
4987 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004988 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004989 "build_fcn": (
4990 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004991 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004992 TosaTensorValuesGen.tvgAddSub,
4993 TosaArgGen.agNone,
4994 ),
4995 "types": [DType.SHAPE],
4996 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4997 },
4998 "sub_shape": {
4999 "op": Op.SUB_SHAPE,
5000 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005001 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005002 "build_fcn": (
5003 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005004 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005005 TosaTensorValuesGen.tvgAddSub,
5006 TosaArgGen.agNone,
5007 ),
5008 "types": [DType.SHAPE],
5009 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5010 },
5011 "mul_shape": {
5012 "op": Op.MUL_SHAPE,
5013 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005014 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005015 "build_fcn": (
5016 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005017 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005018 TosaTensorValuesGen.tvgMul,
5019 TosaArgGen.agNone,
5020 ),
5021 "types": [DType.SHAPE],
5022 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5023 },
5024 "div_shape": {
5025 "op": Op.DIV_SHAPE,
5026 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005027 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005028 "build_fcn": (
5029 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005030 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005031 TosaTensorValuesGen.tvgIntDiv,
5032 TosaArgGen.agNone,
5033 ),
5034 "types": [DType.SHAPE],
5035 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5036 },
5037 "concat_shape": {
5038 "op": Op.CONCAT_SHAPE,
5039 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005040 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005041 "build_fcn": (
5042 build_concat,
5043 TosaTensorGen.tgConcat,
5044 TosaTensorValuesGen.tvgConcat,
5045 TosaArgGen.agNone,
5046 ),
5047 "types": [DType.SHAPE],
5048 "error_if_validators": (),
5049 },
5050 "const_shape": {
5051 "op": Op.CONST_SHAPE,
5052 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005053 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005054 "build_fcn": (
5055 build_const,
5056 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005057 TosaTensorValuesGen.tvgLazyGenDefault,
5058 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005059 ),
5060 "types": [DType.SHAPE],
5061 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005062 }
5063
Kevin Cheng550ccc52021-03-03 11:21:43 -08005064
Eric Kunzee5e26762020-10-13 16:11:07 -07005065class OutputShaper:
5066 # Methods in this class compute the expected output shape and datatype
5067 # for common classes of operations
def __init__(self):
    """OutputShaper keeps no instance state, so there is nothing to set up."""
5070
5071 # These methods return arguments that can be used for
5072 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an elementwise binary op that broadcasts.

    When no ERROR_IF is being generated (``error_name is None``), each
    dimension of ``a`` that equals 1 is replaced by the matching dimension
    of ``b``. In every other case the dimensions of ``a`` are used unchanged.

    ERROR_IF handling:
      * ``ErrorIf.RankMismatch`` - the equal-rank check is skipped.
      * ``ErrorIf.DimensionMismatch`` - one randomly picked dimension is
        increased by 1 (only when there is at least one dimension).
      * ``ErrorIf.WrongOutputType`` - the output dtype is drawn with ``rng``
        from a fixed candidate list, excluding ``a.dtype``.

    Returns the result of ``ser.addOutput(shape, dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    # b.shape is indexed only while broadcasting, so a shorter b in an
    # ERROR_IF test never causes an IndexError here.
    shape = [
        b.shape[idx] if (dim == 1 and broadcasting) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if shape and error_name == ErrorIf.DimensionMismatch:
        # A rank-0 shape has no dimension to corrupt.
        shape[rng.integers(0, len(shape))] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005108
5109 @staticmethod
5110 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005111 assert len(a.shape) == len(b.shape)
5112 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005113
5114 shape = []
5115 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005116 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005117 shape.append(a.shape[i])
5118
Kevin Cheng550ccc52021-03-03 11:21:43 -08005119 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005120
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an elementwise unary op (same shape as ``a``).

    The dtype matches ``a`` unless ``error_name`` is
    ``ErrorIf.WrongOutputType``. In that case ``rng`` picks a dtype from a
    fixed candidate list, excluding ``a.dtype``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005139
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of a select(cond, a, b) op.

    Without an ERROR_IF, any dimension where ``cond`` is 1 becomes the largest
    of the matching ``cond``/``a``/``b`` dimensions. Other dimensions come
    from ``cond``. ERROR_IF cases use ``cond``'s dimensions directly, then:
      * ``ErrorIf.DimensionMismatch`` grows one random dimension by 1
        (rank > 0 only);
      * ``ErrorIf.WrongOutputType`` picks a random dtype other than
        ``a.dtype``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcasting = error_name is None
    shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and broadcasting)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    if shape and error_name == ErrorIf.DimensionMismatch:
        # Only possible when the rank is greater than zero.
        shape[rng.integers(0, len(shape))] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005175
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the BOOL output tensor of an elementwise comparison op.

    Each dimension of ``a`` equal to 1 takes the matching dimension of ``b``
    whenever ``b`` has a dimension at that index. This applies to ERROR_IF
    tests too. Then:
      * ``ErrorIf.DimensionMismatch`` grows one random dimension by 1
        (rank > 0 only);
      * ``ErrorIf.WrongOutputType`` draws a dtype from a fixed list with
        ``rng`` instead of BOOL.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    shape = [
        b.shape[idx] if (dim == 1 and idx < b_rank) else dim
        for idx, dim in enumerate(a.shape)
    ]

    if shape and error_name == ErrorIf.DimensionMismatch:
        shape[rng.integers(0, len(shape))] += 1

    if error_name == ErrorIf.WrongOutputType:
        # None of these candidates is BOOL, so every choice is wrong.
        out_dtype = rng.choice(
            [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
        )
    else:
        out_dtype = DType.BOOL
    return ser.addOutput(shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005210
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction: ``a``'s shape with ``axis`` set to 1.

    ERROR_IF handling:
      * ``ErrorIf.AxisSmallerZero`` / ``ErrorIf.AxisLargerRank`` - the shape
        is copied unchanged.
      * ``ErrorIf.ShapeOfAxisNotOne`` - ``axis`` keeps its input size. If
        that size is 1, it is replaced by ``rng.integers(2, 10)``.
      * ``ErrorIf.WrongOutputType`` - a random dtype other than ``a.dtype``.
    """
    out_shape = a.shape.copy()
    if error_name == ErrorIf.ShapeOfAxisNotOne:
        if out_shape[axis] == 1:
            out_shape[axis] = rng.integers(2, 10)
    elif error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        out_shape[axis] = 1

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005239
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 output tensor of argmax: ``a``'s shape with ``axis`` removed.

    ERROR_IF handling:
      * ``ErrorIf.AxisSmallerZero`` / ``ErrorIf.AxisLargerRank`` - no
        dimension is removed.
      * ``ErrorIf.ArgmaxOutputRankMismatch`` - ``rng`` decides whether to
        drop the leading dimension. Dropping only happens when more than one
        dimension remains; otherwise a trailing 1 is appended.
      * ``ErrorIf.ArgmaxOutputShapeMismatch`` - every dimension grows by
        ``rng.integers(1, 10)``.
      * ``ErrorIf.WrongOutputType`` - a random dtype other than INT32.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    else:
        out_dtype = DType.INT32
    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005275
@staticmethod
def _get_conv_output_type(input_dtype):
    """Return the output dtype a convolution produces for ``input_dtype``.

    FP16/BF16/FP32 inputs keep their dtype, FP8 inputs (E4M3/E5M2) produce
    FP16, INT8/INT4 inputs produce INT32 and INT16 inputs produce INT48.

    Raises:
        ValueError: if ``input_dtype`` is not one of the types listed above.
            (Previously an ``assert True`` that could never fire let the
            function fall through and return ``None``.)
    """
    if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
        return input_dtype
    if input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        return DType.FP16
    if input_dtype in (DType.INT8, DType.INT4):
        return DType.INT32
    if input_dtype in (DType.INT16,):
        return DType.INT48
    # Fail loudly instead of handing None to ser.addOutput as a dtype.
    raise ValueError(f"Unsupported convolution data type {input_dtype}")
5287
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the CONV2D output tensor to ``ser`` and return it.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. Each spatial output
    dim is ``(in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1``.
    The output channel count is ``filter.shape[0]``.
    """
    # (kernel size - 1) scaled by the dilation, for H and W.
    dilated_kh = (filter.shape[1] - 1) * dilations[0]
    dilated_kw = (filter.shape[2] - 1) * dilations[1]

    h = (ifm.shape[1] - 1 + padding[0] + padding[1] - dilated_kh) // strides[0] + 1
    w = (ifm.shape[2] - 1 + padding[2] + padding[3] - dilated_kw) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow H, W or both in multiples of the stride so the non-integer
        # error case is not hit as well.
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in [1, 3]:
            h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Input type is incorrect; just pick a potentially correct output type.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excluded = [DType.FP16]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005341
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the CONV3D output tensor to ``ser`` and return it.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Each spatial output
    dim is ``(in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1``.
    The output channel count is ``filter.shape[0]``.
    """
    # (kernel size - 1) scaled by the dilation, for D, H and W.
    dilated_kd = (filter.shape[1] - 1) * dilations[0]
    dilated_kh = (filter.shape[2] - 1) * dilations[1]
    dilated_kw = (filter.shape[3] - 1) * dilations[2]

    d = (ifm.shape[1] - 1 + padding[0] + padding[1] - dilated_kd) // strides[0] + 1
    h = (ifm.shape[2] - 1 + padding[2] + padding[3] - dilated_kh) // strides[1] + 1
    w = (ifm.shape[3] - 1 + padding[4] + padding[5] - dilated_kw) // strides[2] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow D, H, W or all three in multiples of the stride so the
        # non-integer error case is not hit as well.
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        if which in [1, 4]:
            d += rng.choice(steps) * strides[0]
        if which in [2, 4]:
            h += rng.choice(steps) * strides[1]
        if which in [3, 4]:
            w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Input type is incorrect; just pick a potentially correct output type.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
5403
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the DEPTHWISE_CONV2D output tensor to ``ser`` and return it.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). Each spatial output
    dim is ``(in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1``.
    """
    # (kernel size - 1) scaled by the dilation; the kernel H/W lead in HWCM.
    dilated_kh = (filter.shape[0] - 1) * dilations[0]
    dilated_kw = (filter.shape[1] - 1) * dilations[1]

    h = (ifm.shape[1] - 1 + padding[0] + padding[1] - dilated_kh) // strides[0] + 1
    w = (ifm.shape[2] - 1 + padding[2] + padding[3] - dilated_kw) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow H, W or both in multiples of the stride so the non-integer
        # error case is not hit as well.
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in [1, 3]:
            h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], h, w, out_channels]

    if error_name == ErrorIf.WrongInputType:
        # Input type is incorrect; just pick a potentially correct output type.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005454
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add a 2D pooling output tensor (NHWC) to ``ser`` and return it.

    Each spatial dim is ``(in + pad_a + pad_b - kernel) // stride + 1``. The
    output keeps the input's batch, channels and dtype unless an ERROR_IF
    variant changes them.
    """
    geometry_ok = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if geometry_ok:
        h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Bad stride or padding: the test is invalid anyway, so use 1x1.
        h = w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Grow H, W or both in multiples of the stride so the non-integer
        # error case is not hit as well.
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in [1, 3]:
            h += rng.choice(steps) * stride[0]
        if which in [2, 3]:
            w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))
    return ser.addOutput(ofm_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005494
5495 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01005496 def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005497 # input: N, IC
5498 # filter: OC, IC
5499 # output: N, OC
5500
5501 output_shape = [input.shape[0], filter.shape[0]]
5502
James Ward8b390432022-08-12 20:48:56 +01005503 # Validated in arg_gen (also invalidated for ErrorIf)
5504 out_dtype = accum_dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005505
Kevin Cheng550ccc52021-03-03 11:21:43 -08005506 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005507
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the MATMUL output tensor to ``ser`` and return it.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W). Normally the
    output dtype is ``accum_dtype`` (validated during argument generation).
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Draw from a per-input-dtype list of dtypes the generator treats as
        # wrong outputs. Note: an input dtype outside these branches leaves
        # bad_types unbound (UnboundLocalError), as before.
        if a.dtype == DType.INT8:
            bad_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            bad_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            bad_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            bad_types = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        return ser.addOutput(output_shape, rng.choice(a=bad_types))

    if error_name == ErrorIf.WrongInputType:
        # Input type is incorrect; just pick a potentially correct output type.
        return ser.addOutput(output_shape, DType.INT32)

    return ser.addOutput(output_shape, accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005573
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the CONCAT output tensor to ``ser`` and return it.

    The output starts as a copy of the first input's shape, with ``axis``
    extended by the ``axis`` length of every other input. When that sum is not
    computable (rank mismatch or invalid axis ERROR_IFs), the first input's
    shape is used unchanged.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    can_sum = error_name not in [
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    ]
    if can_sum:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, first.dtype)

    candidate_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    bad_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005609
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the PAD output tensor to ``ser`` and return it.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ERROR_IF variants bump one dim by 1 or 2, change the rank, or pick a
    dtype different from ``a.dtype``.
    """
    output_shape = a.shape.copy()
    for dim_idx in range(len(output_shape)):
        before, after = padding[dim_idx][0], padding[dim_idx][1]
        output_shape[dim_idx] = before + after + output_shape[dim_idx]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005640
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the DIM output tensor to ``ser`` and return it.

    The output always has shape ``[1]``. Its dtype is SHAPE, or for the
    WrongOutputType ERROR_IF a random non-SHAPE dtype.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(non_shape_dtypes)))
    return ser.addOutput([1], bad_dtype)
5661
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the RESHAPE output tensor (a copy of ``shape``) to ``ser``.

    ERROR_IF variants grow every dim by a random 1..9 or pick a dtype
    different from ``a.dtype``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005686
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the SLICE output tensor (a copy of ``size``) to ``ser``.

    ERROR_IF variants perturb each dim, reuse the input shape, change the
    rank, or pick a dtype different from ``input.dtype``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dims of at most 2 are only increased; larger ones may also shrink.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005720
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the TILE output tensor to ``ser`` and return it.

    Output dim ``i`` is ``a.shape[i] * multiples[i]``; ``multiples`` must have
    one entry per input dim. ERROR_IF variants change the rank or pick a dtype
    different from ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for dim_idx in range(len(output_shape)):
        output_shape[dim_idx] = a.shape[dim_idx] * multiples[dim_idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005749
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the TRANSPOSE output tensor to ``ser`` and return it.

    Output dim ``i`` is ``a.shape[perms[i]]``; ``perms`` must have one entry
    per input dim. ERROR_IF variants grow every dim, change the rank, or pick
    a dtype different from ``a.dtype``.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    # For the IndexOutsideBounds / IndexUsedTwice ERROR_IFs the input shape is
    # kept as-is rather than permuted.
    if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        for out_axis in range(len(output_shape)):
            output_shape[out_axis] = a.shape[perms[out_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for out_axis in range(len(output_shape)):
            output_shape[out_axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005782
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the GATHER output tensor to ``ser`` and return it.

    ``values`` is rank 3 and ``indices`` rank 2 with a matching leading dim
    (checked unless the WrongRank ERROR_IF is being generated). The output is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005808
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the SCATTER output tensor to ``ser`` and return it.

    Unless the WrongRank ERROR_IF is being generated, ``values_in`` and
    ``input`` must be rank 3 and ``indices`` rank 2, with N, W and C agreeing
    as asserted below. The output reuses ``values_in.shape`` (the same object,
    not a copy).
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    output_shape = values_in.shape

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values_in.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    bad_dtype = rng.choice(list(set(candidate_dtypes) - {values_in.dtype}))
    return ser.addOutput(output_shape, bad_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005839
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the TABLE output tensor (same shape as ``input``) to ``ser``.

    INT16 input produces INT32 output; any other input dtype maps to INT8.
    The input must be INT8 or INT16 unless the WrongInputType ERROR_IF is
    being generated.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # out_dtype appears exactly once in the list, so filtering equals remove().
        bad_dtypes = [dt for dt in candidate_dtypes if dt != out_dtype]
        out_dtype = rng.choice(bad_dtypes)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005859
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the RESIZE output tensor to ``serializer`` and return it.

    ``scale`` holds (y_n, y_d, x_n, x_d) and the output spatial dims are
    ``OH = ((IH - 1) * y_n - offset[0] + border[0]) // y_d + 1`` and likewise
    for OW with index 1. ``mode`` and ``input_dtype`` are not used here.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the shape arithmetic below stays well defined.
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Keep the output tensor valid; scale, offset or border may have been
        # altered for the ERROR_IF being generated.
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            dim_limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, dim_limit), min(ow, dim_limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        # Shift by whole denominators so the non-integer error case is not hit.
        steps = [1, 2, 3]
        which = rng.choice(steps)
        if which in [1, 3]:
            if oh + y_den >= gtu.MAX_RESIZE_DIMENSION:
                oh -= y_den
                assert oh > 0  # Should have been caught in agResize
            else:
                oh += y_den
        if which in [2, 3]:
            if ow + x_den >= gtu.MAX_RESIZE_DIMENSION:
                ow -= x_den
                assert ow > 0  # Should have been caught in agResize
            else:
                ow += x_den

    if error_name == ErrorIf.WrongRank:
        # input.shape[0] is reused as the last dim; presumably because a
        # wrong-rank input may have no index 3 (TODO confirm).
        output_dims = [input.shape[0], oh, ow, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [input.shape[0] + rng.integers(1, 10), oh, ow, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [input.shape[0], oh, ow, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [input.shape[0], oh, ow, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005938
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output tensor shaped like ``val`` with element type ``out_dtype``.

    A type conversion never changes the tensor shape, so the output simply
    reuses ``val.shape``. ``rng`` and ``error_name`` are accepted for
    signature compatibility with the other shapers but are not used here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005942
@staticmethod
def transposeConv2DOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, error_name=None
):
    """Add the output tensor for a transposed 2D convolution.

    Each spatial output size is ``(in - 1) * stride + pad_a + pad_b + kernel``:
    height uses ``ifm.shape[1]``, ``strides[0]``, ``padding[0:2]`` and
    ``filter.shape[1]``; width uses ``ifm.shape[2]``, ``strides[1]``,
    ``padding[2:4]`` and ``filter.shape[2]``. The output shape is
    ``[ifm.shape[0], out_h, out_w, filter.shape[0]]``.

    Args:
        ser: serializer the output tensor is added to.
        rng: random generator used only for the ERROR_IF variants below.
        ifm: input tensor.
        filter: weight tensor.
        accum_dtype: accepted but not used when shaping the output.
        strides: per-axis strides (index 0 -> height, 1 -> width).
        padding: four padding values (0, 1 -> height; 2, 3 -> width).
        error_name: optional ErrorIf case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    out_h = (ifm.shape[1] - 1) * strides[0] + padding[0] + padding[1] + filter.shape[1]
    out_w = (ifm.shape[2] - 1) * strides[1] + padding[2] + padding[3] + filter.shape[2]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        # Selector: 1 -> corrupt height, 2 -> corrupt width, 3 -> corrupt both.
        selector = rng.choice(deltas)
        if selector in (1, 3):
            out_h = out_h + rng.choice(deltas)
        if selector in (2, 3):
            out_w = out_w + rng.choice(deltas)

    out_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # An invalid input type gives no valid output type to derive, so use
        # a potentially-correct fallback instead.
        result_dtype = DType.INT32
    else:
        result_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are kept out of the "wrong" pool.
        # NOTE(review): presumably because either may be accepted as a valid
        # output for FP16 - confirm against the ERROR_IF validator.
        if ifm.dtype == DType.FP16:
            not_wrong = [DType.FP16, DType.FP32]
        else:
            not_wrong = [result_dtype]
        result_dtype = rng.choice(list(gtu.usableDTypes(excludes=not_wrong)))

    return ser.addOutput(out_shape, result_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005977
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Add the two output tensors for an FFT2D operation.

    Both outputs copy the shape and dtype of ``ifm1`` (presumably the real and
    imaginary parts - confirm against the op definition). Selected ERROR_IF
    cases corrupt the output dtype, the batch dimension (index 0) or one of
    dimensions 1/2.

    Args:
        serializer: serializer the outputs are added to.
        rng: random generator used only for the ERROR_IF variants.
        ifm1, ifm2: input tensors; their dtypes must match, and so must their
            shapes unless FFTInputShapeMismatch is being generated.
        error_name: optional ErrorIf case to provoke.

    Returns:
        A two-element list of the tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        out_shape[target_dim] += rng.integers(1, 10)

    # Both outputs are registered in order and share the same shape/dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
6008
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Add the two output tensors for an RFFT2D operation.

    Every dimension of ``value`` except the last is copied, and the last
    dimension ``W`` becomes ``W // 2 + 1``. Both outputs share that shape and
    the dtype of ``value``, unless an ERROR_IF case corrupts the dtype, the
    batch dimension (index 0) or one of dimensions 1/2.

    Args:
        serializer: serializer the outputs are added to.
        rng: random generator used only for the ERROR_IF variants.
        value: input tensor (rank 3 unless WrongRank is being generated).
        error_name: optional ErrorIf case to provoke.

    Returns:
        A two-element list of the tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1])
    out_shape.append(in_shape[-1] // 2 + 1)

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        out_shape[target_dim] += rng.integers(1, 10)

    # Both outputs are registered in order and share the same shape/dtype.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00006033
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for adding two shape tensors.

    The output copies the shape of ``a`` and has dtype ``DType.SHAPE``.
    ERROR_IF variants either bump one randomly chosen dimension by 1
    (DimensionMismatch) or pick a wrong output dtype (WrongOutputType).

    Args:
        ser: serializer the output tensor is added to.
        rng: random generator used only for the ERROR_IF variants.
        a, b: input tensors with matching dtypes; their ranks must match
            unless RankMismatch is being generated.
        error_name: optional ErrorIf case to provoke.

    Returns:
        Whatever ``ser.addOutput`` returns for the new output tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = a.shape.copy()

    # Rank-0 inputs are not expected; the mismatch fuzzing below needs at
    # least one dimension to modify.
    assert len(out_shape) > 0
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[rng.integers(0, len(out_shape))] += 1

    out_dtype = (
        rng.choice(gtu.get_wrong_output_type(a.dtype))
        if error_name == ErrorIf.WrongOutputType
        else DType.SHAPE
    )
    return ser.addOutput(out_shape, out_dtype)