blob: 40788a2b458714df493d55d8b4a560f6ebd048ce [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Matthew Haddon630c17c2021-10-14 15:05:41 +01006from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01007from datetime import datetime
8from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -07009
Jeremy Johnson1271c442023-09-05 11:39:26 +010010import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000011import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000012import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010013from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010014from generator.tosa_arg_gen import TosaArgGen
15from generator.tosa_arg_gen import TosaQuantGen
16from generator.tosa_arg_gen import TosaTensorGen
17from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000018from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010019from generator.tosa_error_if import TosaErrorIfArgGen
20from generator.tosa_error_if import TosaErrorValidator
21from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010022from generator.tosa_random_gen import TosaHashRandomGenerator
23from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010024from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000025from tosa.DType import DType
26from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010027
# Header placed at the top of the C++ meta data files written by serialize().
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# Configure root logging at import time and fetch this tool's named logger.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
35
Matthew Haddonb724efc2021-08-25 16:40:29 +010036
class TosaTestGen:
    """Builds TOSA operator test graphs and serializes them with their descriptors."""

    # Main compliance dot product statistical test range
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000

    # Level limits - these currently match the 8K level defined in the specification.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192
46
def __init__(self, args):
    """Initialise the generator from the parsed command line ``args``.

    Stores the arguments and builds the operator lists
    (createDynamicOpLists / initOpListDefaults), the quantization helper,
    the desc.json schema validator and the data generator library.
    It then computes the per-dtype random value ranges and creates the
    global RNG from them.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    self.global_rng = None
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Validates desc.json against its JSON schema before it is written
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library can be needed for compliance set up even
    # when data generation is deferred (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def convertFPRange(rangeFP, maxFP):
        """Map "max"/"-max" to +/-maxFP, clamp other values into
        [-maxFP, maxFP] and return them as a sorted tuple."""

        def _limit(value):
            if value == "max":
                return maxFP
            if value == "-max":
                return -maxFP
            if value < 0:
                # Trim to minimum data type value
                return max(value, -maxFP)
            if value > 0:
                # Trim to maximum data type value
                return min(value, maxFP)
            return value

        return tuple(sorted(_limit(value) for value in rangeFP))

    # Shape values use the first two entries of the tensor shape range
    self.random_dtype_range = {DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])}
    float_dtypes = (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2)
    for float_dtype in float_dtypes:
        float_max = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[float_dtype]
        self.random_dtype_range[float_dtype] = convertFPRange(
            args.tensor_fp_value_range, float_max
        )
    self.resetGlobalRNG()
91
def resetGlobalRNG(self):
    """(Re)create the shared random generator from the seed and the per-dtype value ranges."""
    seed = self.random_seed
    dtype_ranges = self.random_dtype_range
    self.global_rng = TosaRandomGenerator(seed, dtype_ranges)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010094
def createSerializer(self, opName, testPath):
    """Start a new serializer for test ``testPath`` of operator ``opName``.

    Creates ``<basePath>/<opName>/<testPath>`` if it does not exist and
    records the relative test path in ``self.testPath``. The constant mode
    is chosen as follows:
      * lazy data generation: ConstMode.INPUTS (constants become files);
      * dump_consts: ConstMode.EMBED_DUMP;
      * otherwise: ConstMode.EMBED (const data embedded in the flatbuffer).
    """
    self.testPath = os.path.join(opName, testPath)
    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    if self.args.lazy_data_gen:
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        constMode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700108
def getSerializer(self):
    """Return the serializer made by the last createSerializer call (None before that)."""
    serializer = self.ser
    return serializer
111
def serialize(self, testName, metaData=None, tags=None):
    """Write out the current test into ``<basePath>/<testPath>``.

    Files written:
      * ``<testName>.tosa`` - the serialized TOSA flatbuffer.
      * ``<testName>_meta_data_gen.cpp`` - only when ``metaData`` has a
        "data_gen" entry and lazy data generation is enabled.
      * ``<testName>_meta_compliance.cpp`` - when ``metaData`` has a
        "compliance" entry.
      * ``desc.json`` - the serializer's JSON descriptor, extended with
        ``metaData`` (key "meta") and ``tags`` (key "tag"). It is validated
        against the schema before anything other than the flatbuffer is
        written.

    :param testName: base file name of the test.
    :param metaData: optional dict of extra meta data.
    :param tags: optional tags stored in desc.json.
    """
    path = Path(self.basePath) / self.testPath

    def _writeMetaDataCpp(suffix, description, varPrefix, data):
        # Output meta data as CPP data: the JSON is wrapped in a C++ raw
        # string literal named <varPrefix>_<test directory name>.
        path_md = path / f"{testName}_meta_{suffix}.cpp"
        with path_md.open("w") as fd:
            fd.write(TOSA_AUTOGENERATED_HEADER)
            fd.write(f"// {description}\n\n")
            fd.write(f'const char* {varPrefix}_{path.stem} = R"(')
            json.dump(data, fd)
            fd.write(')";\n\n')

    # Write out TOSA flatbuffer binary
    path_fb = path / f"{testName}.tosa"
    with path_fb.open("wb") as fd:
        fd.write(self.ser.serialize())

    # Get JSON descriptor from serializer and add the optional extras
    desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
    if metaData:
        desc["meta"] = metaData
    if tags:
        desc["tag"] = tags

    # Validate desc.json before we output it
    self.descSchemaValidator.validate_config(desc)

    if metaData:
        # Data generation meta data is only needed when data is generated later
        if "data_gen" in metaData and self.args.lazy_data_gen:
            _writeMetaDataCpp(
                "data_gen",
                "Test meta data for data generation setup",
                "json_tdg_config",
                metaData["data_gen"],
            )
        if "compliance" in metaData:
            _writeMetaDataCpp(
                "compliance",
                "Test meta data for compliance validation",
                "json_tvf_config",
                metaData["compliance"],
            )

    # Write desc.json
    path_desc = path / "desc.json"
    with path_desc.open("w") as fd:
        json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair and return them in order.

    Random data comes from ``rng.randTensor`` unless lazy data generation is
    enabled, in which case the placeholders are added without data (None).
    """
    assert len(shape_list) == len(dtype_list)

    result = []
    for position, tensor_shape in enumerate(shape_list):
        tensor_dtype = dtype_list[position]
        data = None if self.args.lazy_data_gen else rng.randTensor(tensor_shape, tensor_dtype)
        result.append(self.ser.addPlaceholder(tensor_shape, tensor_dtype, data))
    return result
171
def buildConstTensors(self, rng, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair and return them in order.

    Values are drawn with ``rng.randTensor``; with lazy data generation the
    constants are added with no data (None).
    """
    assert len(shape_list) == len(dtype_list)

    constants = []
    for index, const_shape in enumerate(shape_list):
        const_dtype = dtype_list[index]
        if self.args.lazy_data_gen:
            values = None
        else:
            values = rng.randTensor(const_shape, const_dtype)
        constants.append(self.ser.addConst(const_shape, const_dtype, values))
    return constants
184
def makeShape(self, rng, rank):
    """Return a shape as an int32 numpy array.

    If a target shape has been set (see setTargetShape) it is returned.
    Otherwise ``rank`` dimensions are drawn with ``rng.integers`` using
    ``args.tensor_shape_range[0]`` as low and ``[1]`` as high.
    """
    if not self.targetted_shape:
        shape_low = self.args.tensor_shape_range[0]
        shape_high = self.args.tensor_shape_range[1]
        random_dims = rng.integers(low=shape_low, high=shape_high, size=rank)
        return np.int32(random_dims)
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700196
def setTargetShape(self, shape):
    """Make makeShape return ``shape``; a falsy value restores random shapes."""
    setattr(self, "targetted_shape", shape)
199
def shapeStr(self, shape):
    """Render a shape as dimensions joined by "x" (e.g. [1, 2, 3] -> "1x2x3").

    An empty (rank 0) shape is rendered as "0".
    """
    assert shape is not None
    if len(shape) == 0:
        # Rank 0
        return "0"
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
def typeStr(self, dtype):
    """Return the short name of a dtype, or of a list/tuple of dtypes.

    For a list/tuple (at least two entries) every entry is converted, but
    only the first two names are joined with "x" - a third entry is the
    accumulator type and is left out of the name.

    :raises Exception: if a dtype has no entry in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        known = gtu.DTYPE_ATTRIBUTES
        if dtype not in known:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return known[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(single) for single in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700222
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to ``args.max_batch_size`` in place and return ``shape``.

    Nothing changes when max_batch_size is unset/zero or when explicit
    target shapes were requested.
    """
    batch_limit = self.args.max_batch_size
    if batch_limit and not self.args.target_shapes:
        shape[0] = min(shape[0], batch_limit)
    return shape
228
def makeDimension(self, rng):
    """Return one random dimension from ``rng.randInt`` over the configured shape range."""
    dim_low = self.args.tensor_shape_range[0]
    dim_high = self.args.tensor_shape_range[1]
    return rng.randInt(low=dim_low, high=dim_high)
233
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data for a test's expected output tensor.

    :param op: operator definition dict; reads "op" and optional "compliance".
    :param inputType: input dtype, used to exclude unsupported dot product ops.
    :param argsDict: test arguments; "dg_type" plus mode-specific keys.
    :param outputTensor: expected output tensor (its dtype is checked).
    :param errorName: ERROR_IF name, falsy for normal tests.
    :return: None for ERROR_IF tests or unsupported types, otherwise a dict
        with "mode", "data_type" and, for some modes, a "*_info" entry.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    # Operator-level compliance settings (empty when the op defines none)
    op_compliance = op["compliance"] if "compliance" in op else {}
    dg_type = argsDict["dg_type"]

    if dg_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        s_value = argsDict["s"]
        # Prefer "ksb" over "ks" when both are present
        ks_value = argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]
        compliance_tens["dot_product_info"] = {"s": s_value, "ks": int(ks_value)}
    elif dg_type == gtu.DataGenType.FP_SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "ulp" in op_compliance:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op_compliance["ulp"]}
    elif "relative" in op_compliance:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op_compliance["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op_compliance["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_normal_divisor" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "normal_divisor": op_compliance["abs_error_normal_divisor"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
302
303 # Build Op functions
304 # Create the output tensor (calling OutputShaper as needed)
305 # Do final tweaks to attributes (if necessary for errorIf)
306 # Add Op into graph
307 # Return resulting tensor information or BuildInfo
308
class BuildInfo:
    """Build result: one or more result tensors plus their compliance dicts.

    A single tensor/dict is stored as one-element lists. When a list of
    tensors is given, the compliance argument must be None or a list.
    """

    def __init__(self, resultTensor, complianceDict):
        if not isinstance(resultTensor, list):
            resultTensor = [resultTensor]
            complianceDict = None if complianceDict is None else [complianceDict]
        else:
            assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return {"version": "0.1", "tensors": {name: dict}} for tensors that
        have compliance data, or None when there is none."""
        if self.complianceDictList is None:
            return None
        pairs = zip(self.resultTensorList, self.complianceDictList)
        per_tensor = {tensor.name: info for tensor, info in pairs if info is not None}
        if not per_tensor:
            return None
        return {"version": "0.1", "tensors": per_tensor}
Eric Kunzee5e26762020-10-13 16:11:07 -0700342
def build_unary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a single-input operator to the graph.

    NEGATE also gets a NegateAttribute built from ``qinfo``.
    Returns a TosaTestGen.BuildInfo holding the result tensor and its
    compliance meta data, or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, rng, source, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        input_zp = TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, source.dtype)
        output_zp = TosaQuantGen.getZeroPoint(
            rng, self.args.zeropoint, result_tensor.dtype
        )
        qinfo = [input_zp, output_zp]

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [source.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700404
def build_binary_broadcast(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a broadcasting two-input operator to the graph.

    Returns a TosaTestGen.BuildInfo (result tensor + compliance meta data),
    or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, lhs, rhs, error_name)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700444
def build_arithmetic_right_shift(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an ARITHMETIC_RIGHT_SHIFT operator using ``args_dict["round"]``.

    Returns a TosaTestGen.BuildInfo (result tensor + compliance meta data),
    or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    value, shift = inputs
    # Named round_mode so the builtin round() is not shadowed
    round_mode = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, value, shift, error_name)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [value.name, shift.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=value,
        input2=shift,
        input_dtype=value.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(round_mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, value.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800495
def build_mul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a MUL operator (two operands plus a shift value tensor).

    Results are forced to INT32 for non-FP16/BF16/FP32 inputs. For the
    WrongOutputType ERROR_IF a random INT8/INT16/INT48 output is chosen.
    Returns a TosaTestGen.BuildInfo, or None when ERROR_IF validation
    rejects the test.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    lhs, rhs, shift = inputs

    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, lhs, rhs, error_name)

    # Special for multiply: Force the result to INT32 for INT types
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        result_tensor.setDtype(rng.choice([DType.INT8, DType.INT16, DType.INT48]))

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name, shift.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
def build_table(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TABLE operator using the lookup table in ``args_dict["table"]``.

    Returns a TosaTestGen.BuildInfo (result tensor + compliance meta data),
    or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    lookup_table = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, rng, source, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(lookup_table)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [source.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700603
def build_select(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a SELECT operator with inputs (condition, first value, second value).

    Returns a TosaTestGen.BuildInfo (result tensor + compliance meta data),
    or None when ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    cond, in_a, in_b = inputs

    result_tensor = OutputShaper.selectOp(self.ser, rng, cond, in_a, in_b, error_name)

    # Invalidate Input/Output list for error if checks.
    pCount, cCount = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [cond.name, in_a.name, in_b.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=in_a,
        input3=in_b,
        input_shape=in_a.shape,
        input_dtype=in_a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)
    compliance = self.tensorComplianceMetaData(
        op, in_a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700656
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a binary comparison operator to the serializer.

    The two tensors in ``inputs`` are the operands; the output tensor is
    created by ``OutputShaper.binaryComparisonOp``. ``qinfo`` is accepted
    for interface compatibility and is not read here.

    Returns a ``TosaTestGen.BuildInfo`` holding the result tensor and its
    compliance metadata, or ``None`` when ``evValidateErrorIfs`` returns a
    falsy value.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    # NOTE: rng is handed to several helpers below; their call order is kept
    # as-is, presumably so that generated tests stay reproducible.
    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, rng, lhs, rhs, error_name
    )

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the operand/result name lists may be altered here.
    names_in, names_out = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    validation_kwargs = {
        "op": op,
        "input1": lhs,
        "input2": rhs,
        "input_shape": lhs.shape,
        "input_dtype": lhs.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": names_in,
        "output_list": names_out,
        "num_operands": p_count + c_count,
    }
    validated = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    )
    if not validated:
        return None

    self.ser.addOperator(op["op"], names_in, names_out)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700709
def build_argmax(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an argmax operator over ``args_dict["axis"]`` to the serializer.

    Args:
        rng: random generator passed to the output shaper and ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments; ``"axis"`` is read.
        validator_fcns: ERROR_IF validator functions (optional, matching the
            other builders).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused; accepted for interface compatibility.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700753
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2-D pooling operator to the serializer.

    Args:
        rng: random generator passed to the output shaper, zero-point and
            ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments; ``"kernel"``, ``"stride"`` and ``"pad"``
            are read, plus ``"acc_type"`` when present.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed as the 4th/5th arguments of
            ``PoolAttribute``; ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    ifm = inputs[0]  # renamed from `input` to avoid shadowing the builtin
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100831
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2-D convolution operator to the serializer.

    Args:
        rng: random generator passed to the output shaper, zero-point and
            ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: test arguments; ``"acc_type"``, ``"stride"``, ``"pad"``
            (4 values) and ``"dilation"`` are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed as ``qinfo[0]``/``qinfo[1]`` to
            ``ConvAttribute``; ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # `weight` avoids shadowing builtin `filter`
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; fall back to zero zero-points rather
    # than failing with TypeError on qinfo[0] below. Done after validation so
    # the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700919
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 3-D convolution operator to the serializer.

    Args:
        rng: random generator passed to the output shaper, zero-point and
            ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: test arguments; ``"acc_type"``, ``"stride"``, ``"pad"``
            (6 values) and ``"dilation"`` are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed as ``qinfo[0]``/``qinfo[1]`` to
            ``ConvAttribute``; ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # `weight` avoids shadowing builtin `filter`
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; fall back to zero zero-points rather
    # than failing with TypeError on qinfo[0] below. Done after validation so
    # the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001007
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2-D transpose convolution operator to the serializer.

    Args:
        rng: random generator passed to the output shaper, zero-point and
            ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: test arguments; ``"acc_type"``, ``"stride"``, ``"pad"``
            (4 values, used as the out_pad) and ``"out_shape"`` are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed as ``qinfo[0]``/``qinfo[1]`` to
            ``TransposeConvAttribute``; ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # `weight` avoids shadowing builtin `filter`
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; fall back to zero zero-points rather
    # than failing with TypeError on qinfo[0] below. Done after validation so
    # the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001086
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a 2-D depthwise convolution operator to the serializer.

    Args:
        rng: random generator passed to the output shaper, zero-point and
            ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: test arguments; ``"acc_type"``, ``"stride"``, ``"pad"``
            and ``"dilation"`` are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed as ``qinfo[0]``/``qinfo[1]`` to
            ``ConvAttribute``; ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # `weight` avoids shadowing builtin `filter`
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; fall back to zero zero-points rather
    # than failing with TypeError on qinfo[0] below. Done after validation so
    # the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001173
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a fully-connected operator to the serializer.

    Args:
        rng: random generator passed to the output shaper and ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: ``[ifm, filter, bias]`` tensors.
        args_dict: test arguments; ``"acc_type"`` is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed to ``FullyConnectedAttribute``;
            ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # `weight` avoids shadowing builtin `filter`
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; fall back to zero zero-points rather
    # than failing with TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001230
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a matrix-multiply operator to the serializer.

    Args:
        rng: random generator passed to the output shaper and ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: ``[a, b]`` operand tensors.
        args_dict: test arguments; ``"acc_type"`` is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: two zero-point values passed to ``MatMulAttribute``;
            ``[0, 0]`` is used when None.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # The signature allows qinfo=None; fall back to zero zero-points rather
    # than failing with TypeError on qinfo[0] below.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001287
def build_reduce(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a reduction operator over ``args_dict["axis"]`` to the serializer.

    Args:
        rng: random generator passed to the output shaper and ERROR_IF helpers.
        op: operator description dict; ``op["op"]`` and ``op["operands"]``
            are read.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments; ``"axis"`` is read. For a valid
            ``Op.REDUCE_PRODUCT`` test, ``args_dict["n"]`` is set (in place)
            to the size of the reduced axis before compliance data is built.
        validator_fcns: ERROR_IF validator functions (optional, matching the
            other builders).
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: unused; accepted for interface compatibility.

    Returns:
        ``TosaTestGen.BuildInfo(result_tensor, compliance)``, or ``None`` when
        ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001336
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CLAMP operator with randomly drawn min/max bounds.

    Two random values of the input dtype are drawn. Normally the smaller one
    becomes ``min_val`` and the larger one ``max_val``. For the
    ``ErrorIf.MaxSmallerMin`` test the pair is redrawn until the values
    differ, and the bounds are then assigned the other way round.

    Both bounds are converted to byte vectors and zero-padded to a multiple
    of 8 bytes before being placed in the clamp attribute.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, rng, in_tensor, error_name)

    def draw_pair():
        # Two draws, in this order, from the shared generator.
        return [
            rng.randNumberDType(in_tensor.dtype),
            rng.randNumberDType(in_tensor.dtype),
        ]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot produce max < min once swapped, so redraw.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    min_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
        in_tensor.dtype, [min_val]
    )
    max_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
        in_tensor.dtype, [max_val]
    )
    # Zero-pad each payload up to the next multiple of 8 bytes.
    for payload in (min_val_as_bytes, max_val_as_bytes):
        payload.extend([0] * (-len(payload) % 8))

    clamp_attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)
    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001410
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a single-input operator that takes no attribute.

    The output is shaped with ``OutputShaper.unaryOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, rng, in_tensor, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001458
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONCAT or CONCAT_SHAPE operator over all ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and is serialized with
    an axis attribute. CONCAT_SHAPE always uses axis 0 and is added without
    an attribute.

    Args:
        rng: Random generator passed to the output shaper and to the
            ERROR_IF input/output list invalidation helper.
        op: Operator description; ``op["op"]`` is CONCAT or CONCAT_SHAPE and
            the two counts in ``op["operands"]`` are summed into
            ``num_operands``.
        inputs: Tensors to concatenate. The first one supplies the dtype and
            shape reported to the validators.
        args_dict: Test arguments (``"axis"`` is read for CONCAT only).
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: The ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Use isinstance rather than an exact type() comparison.
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        # CONCAT_SHAPE is serialized without an attribute.
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001524
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a PAD operator.

    ``inputs`` is the tensor to pad followed by a second tensor whose name
    becomes the second operand. ``args_dict["pad"]`` gives the padding used
    for output shaping and validation. The pad constant comes from
    ``args_dict["pad_const_fp"]`` for float dtypes and from
    ``args_dict["pad_const_int"]`` otherwise. It is converted to bytes and
    zero-padded to a multiple of 8 bytes before being stored in the pad
    attribute.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    in_tensor, padding_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, rng, in_tensor, padding, error_name)

    pad_const = (
        pad_const_float if gtu.dtypeIsFloat(in_tensor.dtype) else pad_const_int
    )
    pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
        in_tensor.dtype, [pad_const]
    )
    # Zero-pad the payload up to the next multiple of 8 bytes.
    pad_const_val_as_bytes.extend([0] * (-len(pad_const_val_as_bytes) % 8))

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name, padding_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=in_tensor,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001596
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DIM operator that queries ``args_dict["axis"]`` of one input.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None when the ERROR_IF validation rejects the
        test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, rng, in_tensor, axis, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=in_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001643
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a RESHAPE operator that takes no attribute.

    ``inputs`` is the data tensor followed by a shape tensor; both names
    become operands. The output shape is derived from
    ``args_dict["new_shape"]``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    in_tensor, shape_tensor = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(self.ser, rng, in_tensor, new_shape, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name, shape_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001692
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a REVERSE operator along ``args_dict["axis"]``.

    The output is shaped with ``OutputShaper.unaryOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None when the ERROR_IF validation rejects the
        test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, rng, in_tensor, error_name)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001739
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE operator using ``args_dict["perms"]``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    in_tensor = inputs[0]
    perms = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(self.ser, rng, in_tensor, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        perms=perms,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=in_tensor,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001793
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a SLICE operator that takes no attribute.

    ``inputs`` is (data, start, size). All three names become operands. The
    constant ``args_dict["start"]`` and ``args_dict["size"]`` values drive
    output shaping and validation.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    in_tensor, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, rng, in_tensor, start_values, size_values, error_name
    )

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [in_tensor.name, start_tensor.name, size_tensor.name],
        [out_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=in_tensor,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001848
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TILE operator that takes no attribute.

    ``inputs`` is the data tensor followed by a multiples tensor; both names
    become operands. The output shape is derived from
    ``args_dict["multiples"]``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    in_tensor, multiples_tensor = inputs[0], inputs[1]
    multiples_values = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(
        self.ser, rng, in_tensor, multiples_values, error_name
    )

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [in_tensor.name, multiples_tensor.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=in_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=in_tensor,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, in_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001900
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a GATHER operator from (values, indices) inputs.

    The values tensor supplies the shape and dtype reported to the
    validators and to the compliance metadata.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, rng, values_tensor, indices_tensor, error_name
    )

    placeholder_count, const_count = op["operands"]
    # For ERROR_IF tests the helper may invalidate the operand name lists.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_tensor.name, indices_tensor.name],
        [out_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001950
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001951 def build_scatter(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001952 self,
1953 rng,
1954 op,
1955 inputs,
1956 args_dict,
1957 validator_fcns=None,
1958 error_name=None,
1959 qinfo=None,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001960 ):
1961 assert len(inputs) == 3
1962 values_in, indices, input = inputs
1963 result_tensor = OutputShaper.scatterOp(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001964 self.ser, rng, values_in, indices, input, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001965 )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001966
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001967 # Invalidate Input/Output list for error if checks.
Jeremy Johnson194fe312023-12-07 14:17:57 +00001968 input_list = [values_in.name, indices.name, input.name]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001969 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001970 pCount, cCount = op["operands"]
1971 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001972 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01001973 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001974 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001975
Les Bell729b0352021-11-24 10:28:21 +00001976 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001977 self.ser,
1978 validator_fcns,
1979 error_name,
1980 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001981 input_shape=values_in.shape,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001982 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001983 input_dtype=values_in.dtype,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001984 output_dtype=result_tensor.dtype,
1985 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001986 input_list=input_list,
1987 output_list=output_list,
1988 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00001989 ):
1990 return None
Kevin Cheng77d0f762020-11-24 10:26:32 -08001991
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00001992 self.ser.addOperator(op["op"], input_list, output_list)
Matthew Haddonbb5676f2021-10-13 11:30:30 +01001993
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001994 compliance = self.tensorComplianceMetaData(
1995 op, values_in.dtype, args_dict, result_tensor, error_name
1996 )
1997
1998 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001999
Kevin Cheng550ccc52021-03-03 11:21:43 -08002000 def build_resize(
2001 self,
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002002 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002003 op,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002004 inputs,
2005 args_dict,
Matthew Haddone86fd342021-09-07 16:12:21 +01002006 validator_fcns,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002007 error_name=None,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002008 qinfo=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002009 ):
Tai Lyc5c2a7e2024-02-22 23:26:28 +00002010 assert len(inputs) == 4
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002011 input = inputs[0]
Tai Lyc5c2a7e2024-02-22 23:26:28 +00002012 scale_input = inputs[1]
2013 offset_input = inputs[2]
2014 border_input = inputs[3]
2015
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002016 mode = args_dict["mode"]
2017 scale = args_dict["scale"]
2018 offset = args_dict["offset"]
2019 border = args_dict["border"]
2020 output_dtype = args_dict["output_dtype"]
2021
2022 result_tensor = OutputShaper.resizeOp(
Kevin Cheng550ccc52021-03-03 11:21:43 -08002023 self.ser,
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002024 rng,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002025 input,
2026 mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002027 scale,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002028 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002029 border,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002030 input.dtype,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002031 output_dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002032 error_name,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002033 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002034
Matthew Haddon848efb42021-09-09 12:30:53 +01002035 # Invalidate Input/Output list for error if checks.
Tai Lyc5c2a7e2024-02-22 23:26:28 +00002036 input_list = [
2037 input.name,
2038 scale_input.name,
2039 offset_input.name,
2040 border_input.name,
2041 ]
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002042 output_list = [result_tensor.name]
Matthew Haddon848efb42021-09-09 12:30:53 +01002043 pCount, cCount = op["operands"]
2044 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002045 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002046 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002047 )
Matthew Haddone86fd342021-09-07 16:12:21 +01002048
Les Bell729b0352021-11-24 10:28:21 +00002049 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon848efb42021-09-09 12:30:53 +01002050 self.ser,
2051 validator_fcns,
2052 error_name,
2053 op=op,
2054 mode=mode,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002055 scale=scale,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002056 input_dtype=input.dtype,
Matthew Haddon848efb42021-09-09 12:30:53 +01002057 output_dtype=output_dtype,
Matthew Haddonb6b59e32021-10-07 17:19:20 +01002058 input_shape=input.shape,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002059 output_shape=result_tensor.shape,
Matthew Haddon848efb42021-09-09 12:30:53 +01002060 offset=offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002061 border=border,
Matthew Haddon848efb42021-09-09 12:30:53 +01002062 input_list=input_list,
2063 output_list=output_list,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002064 result_tensors=[result_tensor],
Matthew Haddon848efb42021-09-09 12:30:53 +01002065 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002066 ):
2067 return None
Matthew Haddone86fd342021-09-07 16:12:21 +01002068
Eric Kunzee5e26762020-10-13 16:11:07 -07002069 attr = ts.TosaSerializerAttribute()
Tai Lyc5c2a7e2024-02-22 23:26:28 +00002070 # write empty scale/offset/border into ResizeAttribute
2071 attr.ResizeAttribute([], [], [], mode)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002072 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00002073
2074 compliance = self.tensorComplianceMetaData(
2075 op, input.dtype, args_dict, result_tensor, error_name
2076 )
2077
2078 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002079
evacha0198477222024-01-26 12:25:32 +00002080 def build_const(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002081 self,
2082 rng,
2083 op,
2084 inputs,
2085 args_dict,
2086 validator_fcns=None,
2087 error_name=None,
2088 qinfo=None,
evacha0198477222024-01-26 12:25:32 +00002089 ):
2090 assert len(inputs) == 1
2091 val = inputs[0]
Kevin Cheng17e92022021-10-01 14:33:33 -07002092 self.ser.addOutputTensor(val)
evacha0198477222024-01-26 12:25:32 +00002093
2094 compliance = self.tensorComplianceMetaData(
2095 op, val.dtype, args_dict, val, error_name
2096 )
2097
2098 return TosaTestGen.BuildInfo(val, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002099
2100 # Type Conversion
Jeremy Johnson708da822023-11-15 16:25:45 +00002101 def build_cast(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002102 self,
2103 rng,
2104 op,
2105 inputs,
2106 args_dict,
2107 validator_fcns=None,
2108 error_name=None,
2109 qinfo=None,
Jeremy Johnson708da822023-11-15 16:25:45 +00002110 ):
2111 assert len(inputs) == 1
2112 val = inputs[0]
2113 out_dtype = args_dict["out_type"]
2114
2115 result_tensor = OutputShaper.typeConversionOp(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002116 self.ser, rng, val, out_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002117 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002118
2119 # Invalidate Input/Output list for error if checks.
2120 input_list = [val.name]
Jeremy Johnson708da822023-11-15 16:25:45 +00002121 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002122 pCount, cCount = op["operands"]
2123 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002124 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002125 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002126 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002127
Les Bell729b0352021-11-24 10:28:21 +00002128 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002129 self.ser,
2130 validator_fcns,
2131 error_name,
2132 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002133 input_shape=val.shape,
Jeremy Johnson708da822023-11-15 16:25:45 +00002134 output_shape=result_tensor.shape,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002135 input_dtype=val.dtype,
Jeremy Johnson708da822023-11-15 16:25:45 +00002136 output_dtype=result_tensor.dtype,
2137 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002138 input_list=input_list,
2139 output_list=output_list,
2140 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002141 ):
2142 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +01002143
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002144 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson708da822023-11-15 16:25:45 +00002145
2146 compliance = self.tensorComplianceMetaData(
2147 op, val.dtype, args_dict, result_tensor, error_name
2148 )
2149
2150 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002151
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002152 def build_rescale(
2153 self,
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002154 rng,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002155 op,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002156 inputs,
2157 args_dict,
2158 validator_fcns=None,
2159 error_name=None,
2160 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002161 ):
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002162 assert len(inputs) == 3
Jeremy Johnson587cc842024-02-08 11:45:44 +00002163 val = inputs[0]
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002164 multiplier_val = inputs[1]
2165 shift_val = inputs[2]
Jeremy Johnson587cc842024-02-08 11:45:44 +00002166 out_dtype = args_dict["output_dtype"]
2167 scale32 = args_dict["scale"]
2168 double_round = args_dict["double_round"]
2169 per_channel = args_dict["per_channel"]
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002170 shift_arr = args_dict["shift"]
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002171 multiplier_arr = args_dict["multiplier"]
Jeremy Johnson587cc842024-02-08 11:45:44 +00002172
2173 result_tensor = OutputShaper.typeConversionOp(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002174 self.ser, rng, val, out_dtype, error_name
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002175 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002176
2177 if per_channel:
2178 nc = val.shape[-1]
2179 else:
2180 nc = 1
2181
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002182 in_type_width = gtu.dtypeWidth(val.dtype)
2183 out_type_width = gtu.dtypeWidth(out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002184
Tai Ly8690a082023-12-18 20:40:24 +00002185 input_unsigned = False
2186 output_unsigned = False
2187
Kevin Cheng3a478572021-01-22 17:21:02 -08002188 if val.dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002189 input_zp = rng.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002190 in_type_width += 1
Kevin Chengacb550f2021-06-29 15:32:19 -07002191 elif val.dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002192 input_zp = rng.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002193 in_type_width += 1
Tai Ly8690a082023-12-18 20:40:24 +00002194 input_unsigned = True
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002195 elif error_name in [
2196 ErrorIf.InputZeroPointNotZero,
2197 ErrorIf.U16InputZeroPointNotValid,
2198 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002199 input_zp = rng.randInt(-128, 128)
Matthew Haddonc2025212021-10-08 21:21:05 +01002200 if input_zp == 0:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002201 input_zp = input_zp + rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002202 in_type_width += 1
2203 elif val.dtype == DType.UINT16:
2204 # Must come after ErrorIf.U16InputZeroPointNotValid check
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002205 input_zp = rng.choice([0, 32768])
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002206 in_type_width += 1
Tai Ly8690a082023-12-18 20:40:24 +00002207 input_unsigned = True
Eric Kunzee5e26762020-10-13 16:11:07 -07002208 else:
2209 input_zp = 0
2210
Kevin Cheng3a478572021-01-22 17:21:02 -08002211 if out_dtype == DType.INT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002212 output_zp = rng.randInt(-128, 128)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002213 out_type_width += 1
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002214 elif out_dtype == DType.UINT8:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002215 output_zp = rng.randInt(0, 256)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002216 out_type_width += 1
Tai Ly8690a082023-12-18 20:40:24 +00002217 output_unsigned = True
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002218 elif error_name in [
2219 ErrorIf.OutputZeroPointNotZero,
2220 ErrorIf.U16OutputZeroPointNotValid,
2221 ]:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002222 output_zp = rng.randInt(-128, 128)
Matthew Haddonc2025212021-10-08 21:21:05 +01002223 if output_zp == 0:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002224 output_zp = output_zp + rng.integers(1, 10)
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002225 out_type_width += 1
2226 elif out_dtype == DType.UINT16:
2227 # Must come after ErrorIf.U16OutputZeroPointNotValid check
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002228 output_zp = rng.choice([0, 32768])
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002229 out_type_width += 1
Tai Ly8690a082023-12-18 20:40:24 +00002230 output_unsigned = True
Eric Kunzee5e26762020-10-13 16:11:07 -07002231 else:
2232 output_zp = 0
2233
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002234 min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
2235 max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
Eric Kunzee5e26762020-10-13 16:11:07 -07002236
2237 for i in range(nc):
Eric Kunze750d27d2022-06-30 21:37:09 +00002238 min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
2239 max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002240
Jeremy Johnsonaf090182024-02-13 18:25:39 +00002241 logger.debug(
2242 f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
2243 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002244 if scale32 and error_name is None:
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002245 # Make sure random values are within apply_scale_32 specification
Eric Kunze750d27d2022-06-30 21:37:09 +00002246 # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002247 assert val.placeholderFilename
2248 values = np.load(
2249 os.path.join(self.basePath, self.testPath, val.placeholderFilename)
2250 )
Jeremy Johnsonc0fe04d2022-02-17 12:29:35 +00002251 val_adj = np.subtract(values, input_zp, dtype=np.int64)
2252 val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
2253 val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
Jerry Gec5291692024-01-02 22:29:08 +00002254 val_adj = np.add(val_adj, input_zp, dtype=np.int64)
2255 # Check we can safely convert to the expected dtype
2256 assert (
2257 val_adj.all() >= np.iinfo(values.dtype).min
2258 and val_adj.all() <= np.iinfo(values.dtype).max
2259 )
2260
2261 # Force casting to output datatype
2262 val_adj = val_adj.astype(values.dtype, casting="unsafe")
2263
Jeremy Johnson42c9bae2022-02-01 11:37:58 +00002264 if not np.all(np.array_equal(values, val_adj)):
2265 # Values changed so overwrite file with new values
2266 np.save(
2267 os.path.join(self.basePath, self.testPath, val.placeholderFilename),
2268 val_adj,
2269 False,
2270 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002271
Matthew Haddonc2025212021-10-08 21:21:05 +01002272 # Invalidate Input/Output list for error if checks.
Tai Ly6e1e2bc2024-03-01 20:59:32 +00002273 input_list = [val.name, multiplier_val.name, shift_val.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +00002274 output_list = [result_tensor.name]
Matthew Haddonc2025212021-10-08 21:21:05 +01002275 pCount, cCount = op["operands"]
2276 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002277 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002278 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002279 )
Matthew Haddonc2025212021-10-08 21:21:05 +01002280
2281 qinfo = (input_zp, output_zp)
Les Bell729b0352021-11-24 10:28:21 +00002282 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonc2025212021-10-08 21:21:05 +01002283 self.ser,
2284 validator_fcns,
2285 error_name,
2286 op=op,
2287 input_dtype=val.dtype,
2288 output_dtype=out_dtype,
2289 input_shape=val.shape,
2290 qinfo=qinfo,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002291 scale32=scale32,
2292 double_round=double_round,
Matthew Haddonc2025212021-10-08 21:21:05 +01002293 input_list=input_list,
2294 output_list=output_list,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002295 result_tensors=[result_tensor],
Matthew Haddonc2025212021-10-08 21:21:05 +01002296 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +00002297 ):
2298 return None
Matthew Haddonc2025212021-10-08 21:21:05 +01002299
Eric Kunzee5e26762020-10-13 16:11:07 -07002300 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002301 attr.RescaleAttribute(
2302 input_zp,
2303 output_zp,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002304 scale32,
2305 double_round,
2306 per_channel,
Tai Ly8690a082023-12-18 20:40:24 +00002307 input_unsigned,
2308 output_unsigned,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002309 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002310
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002311 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +00002312
2313 compliance = self.tensorComplianceMetaData(
2314 op, val.dtype, args_dict, result_tensor, error_name
2315 )
2316
2317 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002318
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002319 def _get_condition_tensor(self, rng, op, cond, error_name):
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002320 if error_name == ErrorIf.CondIfCondNotMatchingBool:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002321 cond_type = gtu.get_wrong_output_type(op, rng, DType.BOOL)
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002322 else:
2323 cond_type = DType.BOOL
2324 if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002325 choice = rng.choice([1, 2])
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002326 if choice == 1:
2327 cond_shape = [2]
2328 else:
2329 cond_shape = [1, 2]
2330 else:
2331 # Must be of size 1 (rank 0)
2332 cond_shape = []
2333 cond_tens = self.ser.addConst(cond_shape, cond_type, [cond])
2334 return cond_tens
2335
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002336 def build_cond_if_const(
Jeremy Johnson587cc842024-02-08 11:45:44 +00002337 self,
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002338 rng,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002339 op,
2340 inputs,
2341 args_dict,
2342 validator_fcns=None,
2343 error_name=None,
2344 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002345 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002346 # For cond_if with constants, we're supplied with then/else tensors that we ignore
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002347 # (except for the generated shape) and the condition. Build Then/Else blocks
Eric Kunzee5e26762020-10-13 16:11:07 -07002348 # and fill them with const nodes for the body.
Jeremy Johnson587cc842024-02-08 11:45:44 +00002349 assert len(inputs) == 2
2350 then_tens, else_tens = inputs
2351
2352 cond = args_dict["condition"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002353
2354 # Condition tensor
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002355 cond_tens = self._get_condition_tensor(rng, op, cond, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002356
2357 # Make then/else tensors
2358 out_shape = then_tens.shape
Matthew Haddon630c17c2021-10-14 15:05:41 +01002359
Jeremy Johnson587cc842024-02-08 11:45:44 +00002360 dtype = DType.INT32
2361
Matthew Haddon630c17c2021-10-14 15:05:41 +01002362 # Create an incorrect output shape for error_if tests
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002363 if error_name in [
2364 ErrorIf.CondIfOutputListThenGraphMismatch,
2365 ErrorIf.CondIfOutputListElseGraphMismatch,
2366 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002367 incorrect_shape = deepcopy(then_tens.shape)
2368 for i in range(len(incorrect_shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002369 incorrect_shape[i] += (
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002370 rng.choice([-3, -2, 2, 3])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002371 if incorrect_shape[i] > 3
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002372 else rng.choice([1, 2, 4])
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002373 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002374 incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))
Matthew Haddon630c17c2021-10-14 15:05:41 +01002375
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002376 then_arr = np.int32(rng.integers(0, 256, size=out_shape))
2377 else_arr = np.int32(rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07002378
2379 # And the result tensor based on any of the outputs
Jeremy Johnson587cc842024-02-08 11:45:44 +00002380 result_tensor = self.ser.addOutput(out_shape, dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002381
2382 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002383 then_block = "THEN_BLOCK"
2384 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002385 attr = ts.TosaSerializerAttribute()
2386 attr.CondIfAttribute(then_block, else_block)
2387
2388 # Finally, build the op and the two blocks
Jeremy Johnson587cc842024-02-08 11:45:44 +00002389 self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002390
Jerry Ge9e94af82022-10-27 09:57:00 -07002391 self.ser.addBasicBlock(then_block)
Eric Kunzee5e26762020-10-13 16:11:07 -07002392 # Build the actual then/else tensors inside their blocks
Matthew Haddon630c17c2021-10-14 15:05:41 +01002393 if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
Jeremy Johnson587cc842024-02-08 11:45:44 +00002394 then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002395 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00002396 then_tens = self.ser.addConst(out_shape, dtype, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002397 self.ser.addOutputTensor(then_tens)
2398
Jerry Ge9e94af82022-10-27 09:57:00 -07002399 self.ser.addBasicBlock(else_block)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002400 if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
Jeremy Johnson587cc842024-02-08 11:45:44 +00002401 else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
Matthew Haddon630c17c2021-10-14 15:05:41 +01002402 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00002403 else_tens = self.ser.addConst(out_shape, dtype, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07002404 self.ser.addOutputTensor(else_tens)
2405
Les Bell729b0352021-11-24 10:28:21 +00002406 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002407 self.ser,
2408 validator_fcns,
2409 error_name,
2410 op=op,
Jerry Ge9e94af82022-10-27 09:57:00 -07002411 basicBlocks=self.ser.currRegion.basicBlocks,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002412 cond=cond_tens,
Les Bell729b0352021-11-24 10:28:21 +00002413 ):
2414 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002415
Jeremy Johnson587cc842024-02-08 11:45:44 +00002416 compliance = self.tensorComplianceMetaData(
2417 op, dtype, args_dict, result_tensor, error_name
2418 )
2419
2420 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002421
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002422 def build_cond_if_binary(
Jeremy Johnson587cc842024-02-08 11:45:44 +00002423 self,
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002424 rng,
Jeremy Johnson587cc842024-02-08 11:45:44 +00002425 op,
2426 inputs,
2427 args_dict,
2428 validator_fcns=None,
2429 error_name=None,
2430 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002431 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002432 # For cond_if with a binary op in the then/else blocks, take a and b and
2433 # alternately add or subtract them based on the condition
Jeremy Johnson587cc842024-02-08 11:45:44 +00002434 assert len(inputs) == 2
2435 a, b = inputs
2436
2437 cond = args_dict["condition"]
Eric Kunzee5e26762020-10-13 16:11:07 -07002438
2439 # Condition tensor
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002440 cond_tens = self._get_condition_tensor(rng, op, cond, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -07002441
Jeremy Johnson587cc842024-02-08 11:45:44 +00002442 result_tensor = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002443
2444 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002445 then_block = "THEN_BLOCK"
2446 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07002447 attr = ts.TosaSerializerAttribute()
2448 attr.CondIfAttribute(then_block, else_block)
2449
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002450 if error_name in [
2451 ErrorIf.CondIfInputListThenGraphMismatch,
2452 ErrorIf.CondIfInputListElseGraphMismatch,
2453 ErrorIf.CondIfOutputListElseGraphMismatch,
2454 ErrorIf.CondIfOutputListThenGraphMismatch,
2455 ]:
Matthew Haddon630c17c2021-10-14 15:05:41 +01002456 incorrect_shape = a.shape.copy()
2457 for i in range(len(incorrect_shape)):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +01002458 incorrect_shape[i] += rng.choice([-3, -2, 2, 3])
Matthew Haddon630c17c2021-10-14 15:05:41 +01002459 incorrect_block_input = deepcopy(a)
2460 incorrect_block_input.shape = incorrect_shape
2461
Eric Kunzee5e26762020-10-13 16:11:07 -07002462 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08002463 self.ser.addOperator(
Jeremy Johnson587cc842024-02-08 11:45:44 +00002464 op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
Kevin Cheng550ccc52021-03-03 11:21:43 -08002465 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002466
James Ward24dbc422022-10-19 12:20:31 +01002467 if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
Jeremy Johnson587cc842024-02-08 11:45:44 +00002468 then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
Les Bell6040b4d2021-10-11 12:50:31 +01002469 elif a.dtype in (DType.INT8, DType.INT16):
Jeremy Johnson587cc842024-02-08 11:45:44 +00002470 then_op, else_op = (
2471 self.TOSA_OP_LIST["logical_right_shift"],
2472 self.TOSA_OP_LIST["logical_left_shift"],
2473 )
Les Bell6040b4d2021-10-11 12:50:31 +01002474 else:
2475 assert False, f"No tests for DType: {a.dtype}"
Eric Kunzee5e26762020-10-13 16:11:07 -07002476
Jeremy Johnson587cc842024-02-08 11:45:44 +00002477 # Determine the element-wise binary operation that compliance will need to
2478 # check the results of
2479 compliance_op = then_op if cond else else_op
2480
2481 for block, block_op in ((then_block, then_op), (else_block, else_op)):
Jerry Ge9e94af82022-10-27 09:57:00 -07002482 self.ser.addBasicBlock(block)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002483 if (
2484 error_name == ErrorIf.CondIfInputListThenGraphMismatch
2485 and block == then_block
2486 ) or (
2487 error_name == ErrorIf.CondIfInputListElseGraphMismatch
2488 and block == else_block
2489 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01002490 self.ser.addInputTensor(incorrect_block_input)
2491 self.ser.addInputTensor(b)
2492 tens = self.ser.addOutput(a.shape, a.dtype)
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002493 elif (
2494 error_name == ErrorIf.CondIfOutputListThenGraphMismatch
2495 and block == then_block
2496 ) or (
2497 error_name == ErrorIf.CondIfOutputListElseGraphMismatch
2498 and block == else_block
2499 ):
Matthew Haddon630c17c2021-10-14 15:05:41 +01002500 self.ser.addInputTensor(a)
2501 self.ser.addInputTensor(b)
2502 tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
2503 else:
2504 self.ser.addInputTensor(a)
2505 self.ser.addInputTensor(b)
2506 tens = self.ser.addOutput(a.shape, a.dtype)
Jeremy Johnson587cc842024-02-08 11:45:44 +00002507 self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07002508
Les Bell729b0352021-11-24 10:28:21 +00002509 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddon630c17c2021-10-14 15:05:41 +01002510 self.ser,
2511 validator_fcns,
2512 error_name,
2513 op=op,
2514 a=a,
2515 b=b,
Jerry Ge9e94af82022-10-27 09:57:00 -07002516 basicBlocks=self.ser.currRegion.basicBlocks,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00002517 cond=cond_tens,
Les Bell729b0352021-11-24 10:28:21 +00002518 ):
2519 return None
Matthew Haddon630c17c2021-10-14 15:05:41 +01002520
Jeremy Johnson587cc842024-02-08 11:45:44 +00002521 compliance = self.tensorComplianceMetaData(
2522 compliance_op, a.dtype, args_dict, result_tensor, error_name
2523 )
Eric Kunzee5e26762020-10-13 16:11:07 -07002524
Jeremy Johnson587cc842024-02-08 11:45:44 +00002525 return TosaTestGen.BuildInfo(result_tensor, compliance)
2526
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test whose body adds ``a`` into an accumulator.

    The loop carries three values: an INT32 scalar counter initialised from
    ``args_dict["iterations"]``, the input ``a`` (passed through unchanged)
    and a zero-initialised accumulator.  The COND graph computes
    ``counter > 0``; the BODY graph computes ``acc + a`` and ``counter - 1``.

    For ERROR_IF tests ``error_name`` selects which part of the graph is
    deliberately corrupted (tensor shapes in the input/output lists, or the
    type/shape of the condition output).

    Returns None when the ERROR_IF validators reject the graph, otherwise a
    ``TosaTestGen.BuildInfo`` describing the accumulator output.
    """
    assert len(inputs) == 1
    a = inputs[0]

    def mismatched_copy(tens):
        # Deep copy of tens with each dimension shifted by a random non-zero
        # offset (one rng.choice call per dimension, in dimension order).
        bad = deepcopy(tens)
        for idx, dim in enumerate(bad.shape):
            bad.shape[idx] = dim + rng.choice([-3, -2, 2, 3])
        return bad

    counter = self.ser.addPlaceholder(
        [], DType.INT32, [np.int32(args_dict["iterations"])]
    )

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator placeholder shaped/typed like ``a``; note the initial zero
    # values are built as int32 whatever a.dtype is.
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    # Outer-graph tensors produced by the loop
    counter_out = self.ser.addIntermediate(counter.shape, counter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = mismatched_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    self.ser.addOperator(
        op["op"],
        [counter.name, a.name, acc.name],
        [counter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_list_errors = (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    )
    if error_name in graph_list_errors:
        bad_counter = mismatched_copy(counter)
        if len(bad_counter.shape) == 0:
            # The scalar counter has no dimension to perturb, so add one
            bad_counter.shape.append(rng.choice([-3, -2, 2, 3]))
        bad_acc = mismatched_copy(acc)

    # COND graph: cond_tens = counter > 0
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (bad_counter, a, bad_acc)
    else:
        cond_inputs = (counter, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)
    self.ser.addOperator(
        Op.GREATER, [counter.name, zero_tens.name], [cond_tens.name]
    )

    # BODY graph: acc' = a + acc, counter' = counter - 1
    # (local intermediates must be declared here for the block outputs)
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (bad_counter, a, bad_acc)
    else:
        body_inputs = (counter, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        counter_src, acc_src = bad_counter, bad_acc
    else:
        counter_src, acc_src = counter, acc
    counter_body_out = self.ser.addIntermediate(counter_src.shape, counter_src.dtype)
    acc_body_out = self.ser.addIntermediate(acc_src.shape, acc_src.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [counter.name, one_tens.name], [counter_body_out.name]
    )
    for tens in (counter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    graph_ok = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not graph_ok:
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )
    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002664
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D test from its two input tensors.

    ``args_dict["inverse"]`` selects the forward or inverse transform.  The
    result tensors come from ``OutputShaper.fft2dOp``.  For ERROR_IF tests
    the input/output name lists may be invalidated before validation.

    Returns None when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensors and one compliance
    entry per result tensor.
    """
    assert len(inputs) == 2
    val1, val2 = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]

    # Possibly corrupt the operator's input/output lists for ERROR_IF tests
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [val1.name, val2.name],
        [res.name for res in results],
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val1.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002729
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D test from a single input tensor.

    The result tensors come from ``OutputShaper.rfft2dOp``.  For ERROR_IF
    tests the input/output name lists may be invalidated before validation.

    Returns None when the ERROR_IF validators reject the test, otherwise a
    ``TosaTestGen.BuildInfo`` with the result tensors and one compliance
    entry per result tensor.
    """
    assert len(inputs) == 1
    val = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, rng, val, error_name)

    placeholder_count, const_count = op["operands"]

    # Possibly corrupt the operator's input/output lists for ERROR_IF tests
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [val.name],
        [res.name for res in results],
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)
    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002787
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a test for a two-input shape operator.

    The output tensor is created by ``OutputShaper.addShapeOp``, and the
    operator is serialized without attributes.

    Args:
        rng: random generator for this test; also used to invalidate the
            input/output lists for ERROR_IF tests.
        op: TOSA_OP_LIST entry for the operator.
        inputs: the two input tensors ``(a, b)``.
        args_dict: build arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions, or None.
        error_name: ERROR_IF error being tested, or None.
        qinfo: unused.

    Returns:
        None when the ERROR_IF validators reject the test, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # The first argument is the test's random generator, as in the other
    # builders in this file (previously ``self`` was passed by mistake).
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2840
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters used to generate tests.

    Args:
        op: TOSA_OP_LIST entry; its "rank" (min, max) and "types" are used.
        shapeFilter: requested shapes.  None, [] or a list containing only
            None (the ``[None]`` sentinel) all mean "no shape filter".
        rankFilter: requested ranks, or None for the default rank range.
        dtypeFilter: requested dtypes, or None for all of the op's types.
            A list-valued op dtype matches on its first element.
        testType: "positive" or "negative".
        validator: ERROR_IF validator (negative tests).  It is called with
            ``check=False`` and its "param_reqs" override the filters.

    Returns:
        Dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys.
        Returns None for negative tests without a validator, and for any
        other testType.
    """
    # Create a default testing rank range
    if testType == "positive":
        # 0-3 inclusive to keep test sizes reasonably small.
        default_test_rank_range = range(0, 4)
    else:
        # Some errors do not work with rank 0, use 1-3
        default_test_rank_range = range(1, 4)

    # Calculate the filters based on what is requested and what the operator allows
    rmin, rmax = op["rank"]

    # [None] is the "no shape filter" sentinel, so only treat the filter as
    # specified when it holds at least one real shape.
    shapes_specified = bool(shapeFilter) and any(
        shape is not None for shape in shapeFilter
    )

    if shapes_specified:
        # Specified shapes - ignore rank filter and default to op ranks below
        rankFilter = None
        ranksToCheck = []
    elif rankFilter is None:
        # No set rank filter so ensure default behaviour is bounded
        ranksToCheck = default_test_rank_range
    else:
        ranksToCheck = rankFilter

    # Ensure rank values are allowed by operator
    cleanRankFilter = [rank for rank in ranksToCheck if rmin <= rank <= rmax]

    if shapes_specified or (not cleanRankFilter and rankFilter is None):
        # Shapes specified or default test ranks didn't meet
        # op requirements - so just use op ranks
        cleanRankFilter = range(rmin, rmax + 1)

    dtypes = op["types"]
    if dtypeFilter is not None:
        # Create list of operator dtypes filtered by requested dtypes
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]
    else:
        cleanDtypeFilter = dtypes

    if not shapes_specified:
        shapeFilter = [None]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType == "negative":
        if validator is None:
            return None

        error_arguments = validator(check=False, op=op)["param_reqs"]

        # Set parameters as required by the error validator
        if error_arguments["rank"] is not None:
            rankFilter = error_arguments["rank"]
        else:
            rankFilter = cleanRankFilter

        if error_arguments["dtype"] is not None:
            dtypeFilter = error_arguments["dtype"]
        else:
            dtypeFilter = cleanDtypeFilter

        if error_arguments["shape"] is not None:
            shapeFilter = error_arguments["shape"]
        else:
            # Reduce number of shapes to keep test numbers small
            shapeFilter = shapeFilter[:2]

        return {
            "shapeFilter": shapeFilter,
            "rankFilter": rankFilter,
            "dtypeFilter": dtypeFilter,
        }

    return None
2931
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of test descriptions for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict the tests to; None
            (the default) means no shape filter.
        rankFilter: optional list of ranks to restrict the tests to.
        dtypeFilter: optional list of dtypes to restrict the tests to.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        List of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST, or (negative tests,
            outside level8k mode) if some ERROR_IF types produced no tests.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
        num_error_types_created = 0
    else:
        error_if_validators = [None]
        num_error_types_created = None

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument list consists of (argStr, args) tuples: the
                    # string used in the test name and the build arguments
                    argList = []
                    if agen_fcn:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng

                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

        if error_name is not None:
            # Check the last test is of the error we wanted
            if len(testList) == 0 or testList[-1][3] != error_name:
                if self.args.level8k:
                    logger.info(f"Missing {error_name} tests due to level8k mode")
                else:
                    logger.error(f"ERROR: Failed to create any {error_name} tests")
                    logger.debug(
                        "Last test created: {}".format(
                            testList[-1] if testList else None
                        )
                    )
            else:
                # Successfully created at least one ERROR_IF test
                num_error_types_created += 1

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList
    elif num_error_types_created is not None and not self.args.level8k:
        remaining_error_types = len(error_if_validators) - num_error_types_created
        if remaining_error_types:
            raise Exception(
                f"Failed to create {remaining_error_types} error types for {opName}"
            )

    return testList
3083
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build a single test and serialize it along with its desc.json meta data.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: name of the test.  Also seeds the per-test random generator
            when ``args.stable_rng`` is set.
        dtype_or_dtypeList: a per-operand dtype list, or a single dtype that
            is repeated for every operand (for every shape with CONCAT and
            CONCAT_SHAPE).
        error_name: ERROR_IF error under test, or None.
        shapeList: list of operand shapes.
        argsDict: build arguments.  Must be a dict containing "dg_type".

    Returns:
        True if the test was built and serialized, or False if the build
        function produced no result.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # Concatenation takes a variable number of inputs (one per shape)
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)
    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat_count = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat_count

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Random generator shared by quantization, tensor values and the build
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    if qgen is None:
        qinfo = None
    else:
        qinfo = qgen(
            build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name
        )

    # Extra meta data for the desc.json
    tensMeta = {}

    # Only the argsDict dictionary interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    tags = argsDict.get("tags")

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return False

    # The test is valid: attach any compliance meta data, then serialize
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta, tags)
    return True
Matthew Haddon1c00b712021-10-01 15:51:03 +01003190
def createDynamicOpLists(self):
    """Expand every ``<name>_TEMPLATE`` op into one op per kernel size.

    Each template entry (one whose "template" field is truthy) is copied
    once per kernel into ``TOSA_OP_LIST["<name>_<k0>x<k1>..."]``.  In the
    copy, "filter" is set to the kernel, "template" to False and
    "real_name" to ``<name>``.  The template entry is then removed.

    Kernels are chosen as follows:
      * with ``args.level8k``: fixed kernels using TOSA_8K_LEVEL_MAX_KERNEL;
      * otherwise the ``args.conv_kernels`` entries of the right rank
        (3 for conv3d, 2 otherwise);
      * falling back to the template's own "filter" list when none match.
    """
    template_names = [
        name for name, entry in self.TOSA_OP_LIST.items() if entry.get("template")
    ]

    big_k = self.TOSA_8K_LEVEL_MAX_KERNEL

    for template_name in template_names:
        assert template_name.endswith("_TEMPLATE"), "Found incorrect template"
        realName = template_name[: -len("_TEMPLATE")]
        template = self.TOSA_OP_LIST[template_name]
        k_rank = 3 if realName == "conv3d" else 2

        if self.args.level8k:
            if k_rank == 3:
                kernels = [[1, big_k, 1], [2, 2, big_k]]
            else:
                kernels = [[1, big_k], [big_k, 2]]
        else:
            requested = []
            if len(self.args.conv_kernels) > 0:
                requested = [k for k in self.args.conv_kernels if len(k) == k_rank]
                if not requested:
                    logger.debug(
                        f"{realName} op using defaults as no rank {k_rank} kernels found in {self.args.conv_kernels}"
                    )
            # Fallback to use the defined template kernels
            kernels = requested or template["filter"]

        for kernel in kernels:
            kernel_name = f"{realName}_{'x'.join(str(d) for d in kernel)}"
            self.TOSA_OP_LIST[kernel_name] = {
                **template,
                "filter": kernel,
                "template": False,
                "real_name": realName,
            }

        # The template is no longer needed once its concrete ops exist
        del self.TOSA_OP_LIST[template_name]
Eric Kunzee5e26762020-10-13 16:11:07 -07003240
def initOpListDefaults(self):
    """Check each TOSA_OP_LIST entry and fill in the optional defaults.

    Every entry must have:
      * an "operands" value that unpacks into (placeholders, consts);
      * a "build_fcn" value that unpacks into four functions;
      * a "types" field;
      * an "op" field.
    A generic Exception naming the op is raised for the first problem found.
    A missing "rank" is set to DEFAULT_RANK_RANGE.
    """
    for op_name, entry in self.TOSA_OP_LIST.items():
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {op_name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        if "types" not in entry:
            raise Exception(
                f"Op {op_name} is missing a valid type list in TOSA_OP_LIST"
            )

        if "op" not in entry:
            raise Exception(f"Op {op_name} is missing the Op field in TOSA_OP_LIST")

        # Optional field: default rank range
        if "rank" not in entry:
            entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003282
# Tensor operator list
# 'op': the TOSA Op enum value for the operator (e.g. Op.ARGMAX)
# 'operands': tuple of (number of placeholder, number of const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#         if not specified, defaults to DEFAULT_RANK_RANGE, i.e.
#         (0, gtu.MAX_TENSOR_RANK)
# 'build_fcn': tuple of (build_operator() function, TensorGen function,
#              TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested (for ops using TYPE_CONV, a list
#          of [Input Type 1, Input Type 2, Accumulator Type] entries)
3289 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01003290 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003291
Kevin Cheng550ccc52021-03-03 11:21:43 -08003292 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01003293 TYPE_INT_FP = [
3294 DType.INT8,
3295 DType.INT16,
3296 DType.INT32,
3297 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003298 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003299 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003300 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07003301
Kevin Cheng550ccc52021-03-03 11:21:43 -08003302 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01003303 TYPE_FI32 = [
3304 DType.FP32,
3305 DType.FP16,
3306 DType.BF16,
3307 DType.INT32,
3308 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01003309 TYPE_FIB = [
3310 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003311 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003312 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003313 DType.INT8,
3314 DType.INT16,
3315 DType.INT32,
3316 DType.BOOL,
3317 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003318 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003319
Won Jeon2c34b462024-02-06 18:37:00 +00003320 TYPE_NARROW_INT_FP = [
3321 DType.INT8,
3322 DType.INT16,
3323 DType.FP16,
3324 DType.BF16,
3325 DType.FP32,
3326 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07003327
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003328 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07003329 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003330 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003331 [DType.INT8, DType.INT8, DType.INT32],
3332 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01003333 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003334 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01003335 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003336 [DType.FP32, DType.FP32, DType.FP32],
Won Jeon2c34b462024-02-06 18:37:00 +00003337 [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
3338 [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
Kevin Cheng989cb052021-04-28 16:29:44 -07003339 ]
3340
Jeremy Johnson18a379d2024-03-28 15:53:21 +00003341 DEFAULT_RANK_RANGE = (0, gtu.MAX_TENSOR_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003342
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003343 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
3344 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
3345
evacha01ad8e1e22024-03-19 12:42:17 +00003346 PSEUDO_RANDOM_DATAGEN = {
3347 DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM,),
3348 DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM,),
3349 }
3350 DOT_PRODUCT_DATAGEN = {
3351 DType.FP16: (gtu.DataGenType.DOT_PRODUCT,),
3352 DType.FP32: (gtu.DataGenType.DOT_PRODUCT,),
3353 }
3354 EW_UNARY_DATAGEN = {
evacha014a205112024-03-08 16:39:24 +00003355 DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE),
3356 }
3357 PR_FS_DATAGEN = {
3358 DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
3359 DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL),
evacha01ad8e1e22024-03-19 12:42:17 +00003360 }
3361
Eric Kunzee5e26762020-10-13 16:11:07 -07003362 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003363 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003364 "argmax": {
3365 "op": Op.ARGMAX,
3366 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00003367 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 "build_fcn": (
3369 build_argmax,
3370 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003371 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372 TosaArgGen.agAxis,
3373 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003374 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003375 "error_if_validators": (
3376 TosaErrorValidator.evAxisSmallerZero,
3377 TosaErrorValidator.evAxisLargerRank,
3378 TosaErrorValidator.evArgmaxOutputRankMismatch,
3379 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3380 TosaErrorValidator.evWrongRank,
3381 TosaErrorValidator.evWrongInputType,
3382 TosaErrorValidator.evWrongOutputType,
3383 TosaErrorValidator.evWrongInputList,
3384 TosaErrorValidator.evWrongOutputList,
3385 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003386 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003387 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003388 "avg_pool2d": {
3389 "op": Op.AVG_POOL2D,
3390 "operands": (1, 0),
3391 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003392 "build_fcn": (
3393 build_pool2d,
3394 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003395 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003396 TosaArgGen.agPooling,
3397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003398 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003399 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003400 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003401 "error_if_validators": (
3402 TosaErrorValidator.evKernelSmallerOne,
3403 TosaErrorValidator.evStrideSmallerOne,
3404 TosaErrorValidator.evPadSmallerZero,
3405 TosaErrorValidator.evWrongRank,
3406 TosaErrorValidator.evWrongInputType,
3407 TosaErrorValidator.evWrongOutputType,
3408 TosaErrorValidator.evWrongInputList,
3409 TosaErrorValidator.evWrongOutputList,
3410 TosaErrorValidator.evInputZeroPointNotZero,
3411 TosaErrorValidator.evOutputZeroPointNotZero,
3412 TosaErrorValidator.evPadLargerEqualKernel,
3413 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003414 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003415 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003416 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003417 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003418 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003419 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003420 "conv2d_TEMPLATE": {
3421 "op": Op.CONV2D,
3422 "operands": (1, 2),
3423 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003424 "build_fcn": (
3425 build_conv2d,
3426 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003427 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003428 TosaArgGen.agConv,
3429 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003430 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003431 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003432 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3433 "error_if_validators": (
3434 TosaErrorValidator.evWrongInputType,
3435 TosaErrorValidator.evWrongOutputType,
3436 TosaErrorValidator.evWrongInputList,
3437 TosaErrorValidator.evWrongOutputList,
3438 TosaErrorValidator.evInputZeroPointNotZero,
3439 TosaErrorValidator.evWeightZeroPointNotZero,
3440 TosaErrorValidator.evPadSmallerZero,
3441 TosaErrorValidator.evStrideSmallerOne,
3442 TosaErrorValidator.evDilationSmallerOne,
3443 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003444 TosaErrorValidator.evConvOutputShapeMismatch,
3445 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003446 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003447 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003448 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003449 "broadcastable_bias": True,
3450 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003451 "template": True,
3452 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003453 # Templated operator. Filled in by createDynamicOpLists
3454 "conv3d_TEMPLATE": {
3455 "op": Op.CONV3D,
3456 "operands": (1, 2),
3457 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003458 "build_fcn": (
3459 build_conv3d,
3460 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003461 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003462 TosaArgGen.agConv,
3463 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003464 "qgen": TosaQuantGen.qgConv,
3465 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003466 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3467 "error_if_validators": (
3468 TosaErrorValidator.evWrongInputType,
3469 TosaErrorValidator.evWrongOutputType,
3470 TosaErrorValidator.evWrongInputList,
3471 TosaErrorValidator.evWrongOutputList,
3472 TosaErrorValidator.evInputZeroPointNotZero,
3473 TosaErrorValidator.evWeightZeroPointNotZero,
3474 TosaErrorValidator.evPadSmallerZero,
3475 TosaErrorValidator.evStrideSmallerOne,
3476 TosaErrorValidator.evDilationSmallerOne,
3477 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003478 TosaErrorValidator.evConvOutputShapeMismatch,
3479 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003480 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003481 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003482 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003483 "filter": KERNELS_3D,
Kevin Cheng1533b852021-09-01 12:51:58 -07003484 "template": True,
3485 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003486 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003487 "depthwise_conv2d_TEMPLATE": {
3488 "op": Op.DEPTHWISE_CONV2D,
3489 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003490 "rank": (4, 4),
3491 "build_fcn": (
3492 build_depthwise_conv2d,
3493 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003494 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003495 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003496 ),
3497 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003498 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003499 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3500 "error_if_validators": (
3501 TosaErrorValidator.evWrongInputType,
3502 TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList,
3504 TosaErrorValidator.evWrongOutputList,
3505 TosaErrorValidator.evInputZeroPointNotZero,
3506 TosaErrorValidator.evWeightZeroPointNotZero,
3507 TosaErrorValidator.evPadSmallerZero,
3508 TosaErrorValidator.evStrideSmallerOne,
3509 TosaErrorValidator.evDilationSmallerOne,
3510 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003511 TosaErrorValidator.evConvOutputShapeMismatch,
3512 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003513 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003514 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003515 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003516 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003517 "template": True,
3518 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003519 "fully_connected": {
3520 "op": Op.FULLY_CONNECTED,
3521 "operands": (1, 2),
3522 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003523 "build_fcn": (
3524 build_fully_connected,
3525 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003526 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003527 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003528 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003529 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003530 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003531 "error_if_validators": (
3532 TosaErrorValidator.evInputZeroPointNotZero,
3533 TosaErrorValidator.evWeightZeroPointNotZero,
3534 TosaErrorValidator.evWrongRank,
3535 TosaErrorValidator.evWrongInputType,
3536 TosaErrorValidator.evWrongOutputType,
3537 TosaErrorValidator.evWrongInputList,
3538 TosaErrorValidator.evWrongOutputList,
3539 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003540 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003541 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003542 "matmul": {
3543 "op": Op.MATMUL,
3544 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003545 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003546 "build_fcn": (
3547 build_matmul,
3548 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003549 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003550 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003552 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003553 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003554 "error_if_validators": (
3555 TosaErrorValidator.evInputZeroPointNotZero,
3556 TosaErrorValidator.evWrongRank,
3557 TosaErrorValidator.evWrongInputType,
3558 TosaErrorValidator.evWrongOutputType,
3559 TosaErrorValidator.evWrongInputList,
3560 TosaErrorValidator.evWrongOutputList,
3561 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003562 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003563 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003564 "max_pool2d": {
3565 "op": Op.MAX_POOL2D,
3566 "operands": (1, 0),
3567 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003568 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003569 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003570 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003571 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003572 TosaArgGen.agPooling,
3573 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003574 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003575 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003576 "error_if_validators": (
3577 TosaErrorValidator.evKernelSmallerOne,
3578 TosaErrorValidator.evStrideSmallerOne,
3579 TosaErrorValidator.evPadSmallerZero,
3580 TosaErrorValidator.evWrongRank,
3581 TosaErrorValidator.evWrongInputType,
3582 TosaErrorValidator.evWrongOutputType,
3583 TosaErrorValidator.evWrongInputList,
3584 TosaErrorValidator.evWrongOutputList,
3585 TosaErrorValidator.evPadLargerEqualKernel,
3586 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003587 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003588 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003589 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003590 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003591 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003592 "transpose_conv2d_TEMPLATE": {
3593 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003594 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 "rank": (4, 4),
3596 "build_fcn": (
3597 build_transpose_conv2d,
3598 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003599 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003600 TosaArgGen.agTransposeConv2D,
3601 ),
3602 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003603 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003604 "invalid_test_validators": (
3605 TosaInvalidValidator.ivHeightWidthInvalid,
3606 TosaInvalidValidator.ivNonPositiveOutputShape,
3607 ),
3608 "error_if_validators": (
3609 TosaErrorValidator.evWrongInputType,
3610 TosaErrorValidator.evWrongOutputType,
3611 TosaErrorValidator.evWrongInputList,
3612 TosaErrorValidator.evWrongOutputList,
3613 TosaErrorValidator.evInputZeroPointNotZero,
3614 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003615 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003616 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003617 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003618 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003619 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003620 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003621 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003622 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003623 "template": True,
3624 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003625 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003626 "clamp": {
3627 "op": Op.CLAMP,
3628 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003629 "build_fcn": (
3630 build_clamp,
3631 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003632 TosaTensorValuesGen.tvgLazyGenDefault,
3633 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003635 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003636 "error_if_validators": (
3637 TosaErrorValidator.evMaxSmallerMin,
3638 TosaErrorValidator.evWrongInputType,
3639 TosaErrorValidator.evWrongOutputType,
3640 TosaErrorValidator.evWrongInputList,
3641 TosaErrorValidator.evWrongOutputList,
3642 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003643 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003644 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003645 "sigmoid": {
3646 "op": Op.SIGMOID,
3647 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003648 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003649 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003651 TosaTensorValuesGen.tvgLazyGenDefault,
3652 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003653 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003654 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003655 "error_if_validators": (
3656 TosaErrorValidator.evWrongInputType,
3657 TosaErrorValidator.evWrongOutputType,
3658 TosaErrorValidator.evWrongInputList,
3659 TosaErrorValidator.evWrongOutputList,
3660 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003661 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003662 },
3663 "tanh": {
3664 "op": Op.TANH,
3665 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003666 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003667 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003668 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003669 TosaTensorValuesGen.tvgLazyGenDefault,
3670 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003671 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003672 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003673 "error_if_validators": (
3674 TosaErrorValidator.evWrongInputType,
3675 TosaErrorValidator.evWrongOutputType,
3676 TosaErrorValidator.evWrongInputList,
3677 TosaErrorValidator.evWrongOutputList,
3678 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003679 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003680 "compliance": {
3681 "abs_error_lower_bound": 0.5,
3682 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003683 },
Won Jeon78155c62023-06-10 00:20:04 +00003684 "erf": {
3685 "op": Op.ERF,
3686 "operands": (1, 0),
3687 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003688 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003689 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003690 TosaTensorValuesGen.tvgLazyGenDefault,
3691 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003692 ),
3693 "types": TYPE_FP,
3694 "error_if_validators": (
3695 TosaErrorValidator.evWrongInputType,
3696 TosaErrorValidator.evWrongOutputType,
3697 TosaErrorValidator.evWrongInputList,
3698 TosaErrorValidator.evWrongOutputList,
3699 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003700 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003701 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003702 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003703 # Elementwise Binary Operators
3704 "add": {
3705 "op": Op.ADD,
3706 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 "build_fcn": (
3708 build_binary_broadcast,
3709 TosaTensorGen.tgBroadcastFuzz,
3710 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003711 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003713 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003714 "error_if_validators": (
3715 TosaErrorValidator.evRankMismatch,
3716 TosaErrorValidator.evWrongInputType,
3717 TosaErrorValidator.evWrongOutputType,
3718 TosaErrorValidator.evWrongInputList,
3719 TosaErrorValidator.evWrongOutputList,
3720 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003721 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003722 ),
evacha014a205112024-03-08 16:39:24 +00003723 "data_gen": PR_FS_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003724 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003725 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003726 "arithmetic_right_shift": {
3727 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3728 "operands": (2, 0),
3729 "build_fcn": (
3730 build_arithmetic_right_shift,
3731 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 TosaArgGen.agArithmeticRightShift,
3734 ),
3735 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 "error_if_validators": (
3737 TosaErrorValidator.evRankMismatch,
3738 TosaErrorValidator.evWrongInputType,
3739 TosaErrorValidator.evWrongOutputType,
3740 TosaErrorValidator.evWrongInputList,
3741 TosaErrorValidator.evWrongOutputList,
3742 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003743 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003744 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003745 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003746 "bitwise_and": {
3747 "op": Op.BITWISE_AND,
3748 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 "build_fcn": (
3750 build_binary_broadcast,
3751 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003752 TosaTensorValuesGen.tvgLazyGenDefault,
3753 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003754 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003755 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003756 "error_if_validators": (
3757 TosaErrorValidator.evRankMismatch,
3758 TosaErrorValidator.evWrongInputType,
3759 TosaErrorValidator.evWrongOutputType,
3760 TosaErrorValidator.evWrongInputList,
3761 TosaErrorValidator.evWrongOutputList,
3762 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003763 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003764 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003765 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003766 "bitwise_or": {
3767 "op": Op.BITWISE_OR,
3768 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 "build_fcn": (
3770 build_binary_broadcast,
3771 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003772 TosaTensorValuesGen.tvgLazyGenDefault,
3773 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003774 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003775 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003776 "error_if_validators": (
3777 TosaErrorValidator.evRankMismatch,
3778 TosaErrorValidator.evWrongInputType,
3779 TosaErrorValidator.evWrongOutputType,
3780 TosaErrorValidator.evWrongInputList,
3781 TosaErrorValidator.evWrongOutputList,
3782 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003783 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003784 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003785 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003786 "bitwise_xor": {
3787 "op": Op.BITWISE_XOR,
3788 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 "build_fcn": (
3790 build_binary_broadcast,
3791 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003792 TosaTensorValuesGen.tvgLazyGenDefault,
3793 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003794 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003795 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003796 "error_if_validators": (
3797 TosaErrorValidator.evRankMismatch,
3798 TosaErrorValidator.evWrongInputType,
3799 TosaErrorValidator.evWrongOutputType,
3800 TosaErrorValidator.evWrongInputList,
3801 TosaErrorValidator.evWrongOutputList,
3802 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003803 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003804 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003805 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003806 "intdiv": {
3807 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003808 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 "build_fcn": (
3810 build_binary_broadcast,
3811 TosaTensorGen.tgBroadcastFuzz,
3812 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003813 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003814 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003815 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 "error_if_validators": (
3817 TosaErrorValidator.evRankMismatch,
3818 TosaErrorValidator.evWrongInputType,
3819 TosaErrorValidator.evWrongOutputType,
3820 TosaErrorValidator.evWrongInputList,
3821 TosaErrorValidator.evWrongOutputList,
3822 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003823 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003824 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003825 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003826 "logical_and": {
3827 "op": Op.LOGICAL_AND,
3828 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 "build_fcn": (
3830 build_binary_broadcast,
3831 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003832 TosaTensorValuesGen.tvgLazyGenDefault,
3833 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003834 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003835 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003836 "error_if_validators": (
3837 TosaErrorValidator.evRankMismatch,
3838 TosaErrorValidator.evWrongInputType,
3839 TosaErrorValidator.evWrongOutputType,
3840 TosaErrorValidator.evWrongInputList,
3841 TosaErrorValidator.evWrongOutputList,
3842 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003843 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003844 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003845 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003846 "logical_left_shift": {
3847 "op": Op.LOGICAL_LEFT_SHIFT,
3848 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003849 "build_fcn": (
3850 build_binary_broadcast,
3851 TosaTensorGen.tgBroadcastFuzz,
3852 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003853 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003854 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003855 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003856 "error_if_validators": (
3857 TosaErrorValidator.evRankMismatch,
3858 TosaErrorValidator.evWrongInputType,
3859 TosaErrorValidator.evWrongOutputType,
3860 TosaErrorValidator.evWrongInputList,
3861 TosaErrorValidator.evWrongOutputList,
3862 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003863 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003864 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003865 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003866 "logical_right_shift": {
3867 "op": Op.LOGICAL_RIGHT_SHIFT,
3868 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003869 "build_fcn": (
3870 build_binary_broadcast,
3871 TosaTensorGen.tgBroadcastFuzz,
3872 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003873 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003874 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003875 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003876 "error_if_validators": (
3877 TosaErrorValidator.evRankMismatch,
3878 TosaErrorValidator.evWrongInputType,
3879 TosaErrorValidator.evWrongOutputType,
3880 TosaErrorValidator.evWrongInputList,
3881 TosaErrorValidator.evWrongOutputList,
3882 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003883 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003884 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003885 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003886 "logical_or": {
3887 "op": Op.LOGICAL_OR,
3888 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003889 "build_fcn": (
3890 build_binary_broadcast,
3891 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003892 TosaTensorValuesGen.tvgLazyGenDefault,
3893 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003894 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003895 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003896 "error_if_validators": (
3897 TosaErrorValidator.evRankMismatch,
3898 TosaErrorValidator.evWrongInputType,
3899 TosaErrorValidator.evWrongOutputType,
3900 TosaErrorValidator.evWrongInputList,
3901 TosaErrorValidator.evWrongOutputList,
3902 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003903 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003904 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003906 "logical_xor": {
3907 "op": Op.LOGICAL_XOR,
3908 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 "build_fcn": (
3910 build_binary_broadcast,
3911 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003912 TosaTensorValuesGen.tvgLazyGenDefault,
3913 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003914 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003915 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 "error_if_validators": (
3917 TosaErrorValidator.evRankMismatch,
3918 TosaErrorValidator.evWrongInputType,
3919 TosaErrorValidator.evWrongOutputType,
3920 TosaErrorValidator.evWrongInputList,
3921 TosaErrorValidator.evWrongOutputList,
3922 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003923 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003924 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003925 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003926 "maximum": {
3927 "op": Op.MAXIMUM,
3928 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 "build_fcn": (
3930 build_binary_broadcast,
3931 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003932 TosaTensorValuesGen.tvgLazyGenDefault,
3933 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003935 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003936 "error_if_validators": (
3937 TosaErrorValidator.evRankMismatch,
3938 TosaErrorValidator.evWrongInputType,
3939 TosaErrorValidator.evWrongOutputType,
3940 TosaErrorValidator.evWrongInputList,
3941 TosaErrorValidator.evWrongOutputList,
3942 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003943 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003944 ),
evacha014a205112024-03-08 16:39:24 +00003945 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003947 "minimum": {
3948 "op": Op.MINIMUM,
3949 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 "build_fcn": (
3951 build_binary_broadcast,
3952 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003953 TosaTensorValuesGen.tvgLazyGenDefault,
3954 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003956 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003957 "error_if_validators": (
3958 TosaErrorValidator.evRankMismatch,
3959 TosaErrorValidator.evWrongInputType,
3960 TosaErrorValidator.evWrongOutputType,
3961 TosaErrorValidator.evWrongInputList,
3962 TosaErrorValidator.evWrongOutputList,
3963 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003964 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003965 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003966 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 "mul": {
3969 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003970 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003971 "build_fcn": (
3972 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003973 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003974 TosaTensorValuesGen.tvgMul,
3975 TosaArgGen.agMul,
3976 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003977 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003978 "error_if_validators": (
3979 TosaErrorValidator.evWrongInputType,
3980 TosaErrorValidator.evWrongOutputType,
3981 TosaErrorValidator.evWrongInputList,
3982 TosaErrorValidator.evWrongOutputList,
3983 TosaErrorValidator.evRankMismatch,
3984 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003985 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003986 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003987 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003988 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 "pow": {
3991 "op": Op.POW,
3992 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003993 "build_fcn": (
3994 build_binary_broadcast,
3995 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003996 TosaTensorValuesGen.tvgPow,
3997 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003998 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003999 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004000 "error_if_validators": (
4001 TosaErrorValidator.evRankMismatch,
4002 TosaErrorValidator.evWrongInputType,
4003 TosaErrorValidator.evWrongOutputType,
4004 TosaErrorValidator.evWrongInputList,
4005 TosaErrorValidator.evWrongOutputList,
4006 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004007 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004008 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004009 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004010 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004011 "sub": {
4012 "op": Op.SUB,
4013 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004014 "build_fcn": (
4015 build_binary_broadcast,
4016 TosaTensorGen.tgBroadcastFuzz,
4017 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004018 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004020 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004021 "error_if_validators": (
4022 TosaErrorValidator.evRankMismatch,
4023 TosaErrorValidator.evWrongInputType,
4024 TosaErrorValidator.evWrongOutputType,
4025 TosaErrorValidator.evWrongInputList,
4026 TosaErrorValidator.evWrongOutputList,
4027 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004028 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004029 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004030 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004031 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004032 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004033 "table": {
4034 "op": Op.TABLE,
4035 # Use the automatic generation functions to create the input array
4036 # but create the table tensor in the build function, as it may be
4037 # a different type from the input
4038 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 "build_fcn": (
4040 build_table,
4041 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004042 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004043 TosaArgGen.agTable,
4044 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004045 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004046 "error_if_validators": (
4047 TosaErrorValidator.evWrongInputType,
4048 TosaErrorValidator.evWrongOutputType,
4049 TosaErrorValidator.evWrongInputList,
4050 TosaErrorValidator.evWrongOutputList,
4051 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004052 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004053 # Elementwise Unary operators
4054 "abs": {
4055 "op": Op.ABS,
4056 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004057 "build_fcn": (
4058 build_unary,
4059 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004060 TosaTensorValuesGen.tvgLazyGenDefault,
4061 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004063 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004064 "error_if_validators": (
4065 TosaErrorValidator.evWrongInputType,
4066 TosaErrorValidator.evWrongOutputType,
4067 TosaErrorValidator.evWrongInputList,
4068 TosaErrorValidator.evWrongOutputList,
4069 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004070 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004071 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 "bitwise_not": {
4073 "op": Op.BITWISE_NOT,
4074 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004075 "build_fcn": (
4076 build_unary,
4077 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004078 TosaTensorValuesGen.tvgLazyGenDefault,
4079 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004080 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004082 "error_if_validators": (
4083 TosaErrorValidator.evWrongInputType,
4084 TosaErrorValidator.evWrongOutputType,
4085 TosaErrorValidator.evWrongInputList,
4086 TosaErrorValidator.evWrongOutputList,
4087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004088 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004089 "ceil": {
4090 "op": Op.CEIL,
4091 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004092 "build_fcn": (
4093 build_unary,
4094 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004095 TosaTensorValuesGen.tvgLazyGenDefault,
4096 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004097 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004098 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004099 "error_if_validators": (
4100 TosaErrorValidator.evWrongInputType,
4101 TosaErrorValidator.evWrongOutputType,
4102 TosaErrorValidator.evWrongInputList,
4103 TosaErrorValidator.evWrongOutputList,
4104 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004105 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004106 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004107 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004108 "clz": {
4109 "op": Op.CLZ,
4110 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004111 "build_fcn": (
4112 build_unary,
4113 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004114 TosaTensorValuesGen.tvgLazyGenDefault,
4115 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004116 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004118 "error_if_validators": (
4119 TosaErrorValidator.evWrongInputType,
4120 TosaErrorValidator.evWrongOutputType,
4121 TosaErrorValidator.evWrongInputList,
4122 TosaErrorValidator.evWrongOutputList,
4123 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004124 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004125 "cos": {
4126 "op": Op.COS,
4127 "operands": (1, 0),
4128 "build_fcn": (
4129 build_unary,
4130 TosaTensorGen.tgBasic,
4131 TosaTensorValuesGen.tvgLazyGenDefault,
4132 TosaArgGen.agNone,
4133 ),
4134 "types": TYPE_FP,
4135 "error_if_validators": (
4136 TosaErrorValidator.evWrongInputType,
4137 TosaErrorValidator.evWrongOutputType,
4138 TosaErrorValidator.evWrongInputList,
4139 TosaErrorValidator.evWrongOutputList,
4140 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004141 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004142 "compliance": {"abs_error_normal_divisor": 2},
4143 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004144 "exp": {
4145 "op": Op.EXP,
4146 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004147 "build_fcn": (
4148 build_unary,
4149 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004150 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004151 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004152 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004153 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004154 "error_if_validators": (
4155 TosaErrorValidator.evWrongInputType,
4156 TosaErrorValidator.evWrongOutputType,
4157 TosaErrorValidator.evWrongInputList,
4158 TosaErrorValidator.evWrongOutputList,
4159 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004160 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004161 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004162 "floor": {
4163 "op": Op.FLOOR,
4164 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004165 "build_fcn": (
4166 build_unary,
4167 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004168 TosaTensorValuesGen.tvgLazyGenDefault,
4169 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004170 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004171 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004172 "error_if_validators": (
4173 TosaErrorValidator.evWrongInputType,
4174 TosaErrorValidator.evWrongOutputType,
4175 TosaErrorValidator.evWrongInputList,
4176 TosaErrorValidator.evWrongOutputList,
4177 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004178 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004179 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004180 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004181 "log": {
4182 "op": Op.LOG,
4183 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004184 "build_fcn": (
4185 build_unary,
4186 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004187 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004188 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004189 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004190 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004191 "error_if_validators": (
4192 TosaErrorValidator.evWrongInputType,
4193 TosaErrorValidator.evWrongOutputType,
4194 TosaErrorValidator.evWrongInputList,
4195 TosaErrorValidator.evWrongOutputList,
4196 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004197 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004198 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 "logical_not": {
4201 "op": Op.LOGICAL_NOT,
4202 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004203 "build_fcn": (
4204 build_unary,
4205 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004206 TosaTensorValuesGen.tvgLazyGenDefault,
4207 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004208 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004209 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004210 "error_if_validators": (
4211 TosaErrorValidator.evWrongInputType,
4212 TosaErrorValidator.evWrongOutputType,
4213 TosaErrorValidator.evWrongInputList,
4214 TosaErrorValidator.evWrongOutputList,
4215 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004217 "negate": {
4218 "op": Op.NEGATE,
4219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 "build_fcn": (
4221 build_unary,
4222 TosaTensorGen.tgBasic,
4223 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004224 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004226 "qgen": TosaQuantGen.qgUnary,
4227 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004228 "error_if_validators": (
4229 TosaErrorValidator.evInputZeroPointNotZero,
4230 TosaErrorValidator.evOutputZeroPointNotZero,
4231 TosaErrorValidator.evWrongInputType,
4232 TosaErrorValidator.evWrongOutputType,
4233 TosaErrorValidator.evWrongInputList,
4234 TosaErrorValidator.evWrongOutputList,
4235 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004236 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004237 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004238 "reciprocal": {
4239 "op": Op.RECIPROCAL,
4240 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004241 "build_fcn": (
4242 build_unary,
4243 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004244 TosaTensorValuesGen.tvgLazyGenDefault,
4245 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004246 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004247 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evWrongInputType,
4250 TosaErrorValidator.evWrongOutputType,
4251 TosaErrorValidator.evWrongInputList,
4252 TosaErrorValidator.evWrongOutputList,
4253 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004254 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004255 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004257 "rsqrt": {
4258 "op": Op.RSQRT,
4259 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004260 "build_fcn": (
4261 build_unary,
4262 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004263 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004264 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004265 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004266 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004267 "error_if_validators": (
4268 TosaErrorValidator.evWrongInputType,
4269 TosaErrorValidator.evWrongOutputType,
4270 TosaErrorValidator.evWrongInputList,
4271 TosaErrorValidator.evWrongOutputList,
4272 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004273 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004274 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004275 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004276 "sin": {
4277 "op": Op.SIN,
4278 "operands": (1, 0),
4279 "build_fcn": (
4280 build_unary,
4281 TosaTensorGen.tgBasic,
4282 TosaTensorValuesGen.tvgLazyGenDefault,
4283 TosaArgGen.agNone,
4284 ),
4285 "types": TYPE_FP,
4286 "error_if_validators": (
4287 TosaErrorValidator.evWrongInputType,
4288 TosaErrorValidator.evWrongOutputType,
4289 TosaErrorValidator.evWrongInputList,
4290 TosaErrorValidator.evWrongOutputList,
4291 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004292 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004293 "compliance": {"abs_error_normal_divisor": 2},
4294 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004295 # Elementwise Ternary operators
4296 "select": {
4297 "op": Op.SELECT,
4298 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004299 "build_fcn": (
4300 build_select,
4301 TosaTensorGen.tgBroadcastFuzz,
4302 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004303 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004304 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004305 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004306 "error_if_validators": (
4307 TosaErrorValidator.evRankMismatch,
4308 TosaErrorValidator.evWrongInputType,
4309 TosaErrorValidator.evWrongOutputType,
4310 TosaErrorValidator.evWrongInputList,
4311 TosaErrorValidator.evWrongOutputList,
4312 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004313 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004314 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004315 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004317 # Comparison operators
4318 "equal": {
4319 "op": Op.EQUAL,
4320 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004321 "build_fcn": (
4322 build_comparison,
4323 TosaTensorGen.tgBroadcastFuzz,
4324 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004325 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004326 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004327 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004328 "error_if_validators": (
4329 TosaErrorValidator.evRankMismatch,
4330 TosaErrorValidator.evWrongInputType,
4331 TosaErrorValidator.evWrongOutputType,
4332 TosaErrorValidator.evWrongInputList,
4333 TosaErrorValidator.evWrongOutputList,
4334 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004335 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004336 ),
evacha014a205112024-03-08 16:39:24 +00004337 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004338 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004339 "greater_equal": {
4340 "op": Op.GREATER_EQUAL,
4341 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004342 "build_fcn": (
4343 build_comparison,
4344 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004345 TosaTensorValuesGen.tvgLazyGenDefault,
4346 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004348 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004349 "error_if_validators": (
4350 TosaErrorValidator.evRankMismatch,
4351 TosaErrorValidator.evWrongInputType,
4352 TosaErrorValidator.evWrongOutputType,
4353 TosaErrorValidator.evWrongInputList,
4354 TosaErrorValidator.evWrongOutputList,
4355 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004356 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004357 ),
evacha014a205112024-03-08 16:39:24 +00004358 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004360 "greater": {
4361 "op": Op.GREATER,
4362 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004363 "build_fcn": (
4364 build_comparison,
4365 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004366 TosaTensorValuesGen.tvgLazyGenDefault,
4367 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004368 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004369 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004370 "error_if_validators": (
4371 TosaErrorValidator.evRankMismatch,
4372 TosaErrorValidator.evWrongInputType,
4373 TosaErrorValidator.evWrongOutputType,
4374 TosaErrorValidator.evWrongInputList,
4375 TosaErrorValidator.evWrongOutputList,
4376 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004377 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004378 ),
evacha014a205112024-03-08 16:39:24 +00004379 "data_gen": PR_FS_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004380 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004381 # Reduction operators
4382 "reduce_all": {
4383 "op": Op.REDUCE_ALL,
4384 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004385 "build_fcn": (
4386 build_reduce,
4387 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004388 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004389 TosaArgGen.agAxis,
4390 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004392 "error_if_validators": (
4393 TosaErrorValidator.evAxisLargerRank,
4394 TosaErrorValidator.evAxisSmallerZero,
4395 TosaErrorValidator.evShapeOfAxisNotOne,
4396 TosaErrorValidator.evWrongInputType,
4397 TosaErrorValidator.evWrongOutputType,
4398 TosaErrorValidator.evWrongRank,
4399 TosaErrorValidator.evWrongInputList,
4400 TosaErrorValidator.evWrongOutputList,
4401 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004402 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004403 "reduce_any": {
4404 "op": Op.REDUCE_ANY,
4405 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004406 "build_fcn": (
4407 build_reduce,
4408 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004409 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004410 TosaArgGen.agAxis,
4411 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004412 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004413 "error_if_validators": (
4414 TosaErrorValidator.evAxisLargerRank,
4415 TosaErrorValidator.evAxisSmallerZero,
4416 TosaErrorValidator.evShapeOfAxisNotOne,
4417 TosaErrorValidator.evWrongInputType,
4418 TosaErrorValidator.evWrongOutputType,
4419 TosaErrorValidator.evWrongRank,
4420 TosaErrorValidator.evWrongInputList,
4421 TosaErrorValidator.evWrongOutputList,
4422 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004423 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004424 "reduce_max": {
4425 "op": Op.REDUCE_MAX,
4426 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004427 "build_fcn": (
4428 build_reduce,
4429 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004430 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004431 TosaArgGen.agAxis,
4432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004433 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004434 "error_if_validators": (
4435 TosaErrorValidator.evAxisLargerRank,
4436 TosaErrorValidator.evAxisSmallerZero,
4437 TosaErrorValidator.evShapeOfAxisNotOne,
4438 TosaErrorValidator.evWrongInputType,
4439 TosaErrorValidator.evWrongOutputType,
4440 TosaErrorValidator.evWrongRank,
4441 TosaErrorValidator.evWrongInputList,
4442 TosaErrorValidator.evWrongOutputList,
4443 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004444 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004445 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004446 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004447 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004448 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004449 "build_fcn": (
4450 build_reduce,
4451 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004452 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004453 TosaArgGen.agAxis,
4454 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004455 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004456 "error_if_validators": (
4457 TosaErrorValidator.evAxisLargerRank,
4458 TosaErrorValidator.evAxisSmallerZero,
4459 TosaErrorValidator.evShapeOfAxisNotOne,
4460 TosaErrorValidator.evWrongInputType,
4461 TosaErrorValidator.evWrongOutputType,
4462 TosaErrorValidator.evWrongRank,
4463 TosaErrorValidator.evWrongInputList,
4464 TosaErrorValidator.evWrongOutputList,
4465 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004466 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004467 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004468 "reduce_product": {
4469 "op": Op.REDUCE_PRODUCT,
4470 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004471 "build_fcn": (
4472 build_reduce,
4473 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004474 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004475 TosaArgGen.agAxis,
4476 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004477 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004478 "error_if_validators": (
4479 TosaErrorValidator.evAxisLargerRank,
4480 TosaErrorValidator.evAxisSmallerZero,
4481 TosaErrorValidator.evShapeOfAxisNotOne,
4482 TosaErrorValidator.evWrongInputType,
4483 TosaErrorValidator.evWrongOutputType,
4484 TosaErrorValidator.evWrongRank,
4485 TosaErrorValidator.evWrongInputList,
4486 TosaErrorValidator.evWrongOutputList,
4487 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004488 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004489 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004490 "reduce_sum": {
4491 "op": Op.REDUCE_SUM,
4492 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004493 "build_fcn": (
4494 build_reduce,
4495 TosaTensorGen.tgBasic,
4496 TosaTensorValuesGen.tvgReduceSum,
4497 TosaArgGen.agAxis,
4498 ),
James Ward24dbc422022-10-19 12:20:31 +01004499 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004500 "error_if_validators": (
4501 TosaErrorValidator.evAxisLargerRank,
4502 TosaErrorValidator.evAxisSmallerZero,
4503 TosaErrorValidator.evShapeOfAxisNotOne,
4504 TosaErrorValidator.evWrongInputType,
4505 TosaErrorValidator.evWrongOutputType,
4506 TosaErrorValidator.evWrongRank,
4507 TosaErrorValidator.evWrongInputList,
4508 TosaErrorValidator.evWrongOutputList,
4509 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004510 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004511 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004512 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004513 "concat": {
4514 "op": Op.CONCAT,
4515 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004516 "build_fcn": (
4517 build_concat,
4518 TosaTensorGen.tgConcat,
4519 TosaTensorValuesGen.tvgConcat,
4520 TosaArgGen.agAxis,
4521 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004522 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004523 "error_if_validators": (
4524 TosaErrorValidator.evAxisLargerRank,
4525 TosaErrorValidator.evAxisSmallerZero,
4526 TosaErrorValidator.evConcatInputRankMismatch,
4527 TosaErrorValidator.evConcatShapeSumMismatch,
4528 TosaErrorValidator.evConcatInputDimMismatch,
4529 TosaErrorValidator.evWrongInputType,
4530 TosaErrorValidator.evWrongOutputType,
4531 TosaErrorValidator.evWrongOutputList,
4532 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004533 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004534 },
4535 "pad": {
4536 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004537 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004538 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004539 "build_fcn": (
4540 build_pad,
4541 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004542 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004543 TosaArgGen.agPad,
4544 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004545 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004546 "error_if_validators": (
4547 TosaErrorValidator.evWrongInputType,
4548 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004549 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004550 TosaErrorValidator.evWrongOutputType,
4551 TosaErrorValidator.evWrongInputList,
4552 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004553 TosaErrorValidator.evRankMismatch,
4554 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004555 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004556 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004557 },
Won Jeona21b2e82023-08-10 10:33:01 +00004558 "dim": {
4559 "op": Op.DIM,
4560 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004561 "rank": (1, gtu.MAX_TENSOR_RANK),
Won Jeona21b2e82023-08-10 10:33:01 +00004562 "build_fcn": (
4563 build_dim,
4564 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004565 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004566 TosaArgGen.agAxis,
4567 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004568 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004569 "error_if_validators": (
4570 TosaErrorValidator.evAxisLargerRank,
4571 TosaErrorValidator.evAxisSmallerZero,
4572 TosaErrorValidator.evWrongInputType,
4573 TosaErrorValidator.evWrongInputList,
4574 TosaErrorValidator.evWrongOutputList,
4575 TosaErrorValidator.evWrongRank,
4576 ),
4577 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004578 "reshape": {
4579 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004580 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004581 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004582 "build_fcn": (
4583 build_reshape,
4584 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004585 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004586 TosaArgGen.agReshape,
4587 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004588 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004589 "error_if_validators": (
4590 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4591 TosaErrorValidator.evWrongInputType,
4592 TosaErrorValidator.evWrongOutputType,
4593 TosaErrorValidator.evWrongInputList,
4594 TosaErrorValidator.evWrongOutputList,
4595 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004596 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004597 },
4598 "reverse": {
4599 "op": Op.REVERSE,
4600 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004601 "build_fcn": (
4602 build_reverse,
4603 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004604 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004605 TosaArgGen.agAxis,
4606 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004607 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004608 "error_if_validators": (
4609 TosaErrorValidator.evAxisSmallerZero,
4610 TosaErrorValidator.evAxisLargerRank,
4611 TosaErrorValidator.evWrongInputType,
4612 TosaErrorValidator.evWrongOutputType,
4613 TosaErrorValidator.evWrongInputList,
4614 TosaErrorValidator.evWrongOutputList,
4615 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004616 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004617 },
4618 "slice": {
4619 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004620 "operands": (3, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004621 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004622 "build_fcn": (
4623 build_slice,
4624 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004625 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004626 TosaArgGen.agSlice,
4627 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004628 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004629 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004630 # TODO Turn off these error categories for now as the reference
4631 # model cannot allocate memory space for empty tensor. We probably
4632 # can report an accurate error message at the right place during
4633 # execution.
4634 # TosaErrorValidator.evStartSmallerZero,
4635 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004636 TosaErrorValidator.evStartSizeOutsideBounds,
4637 TosaErrorValidator.evSizeOutputShapeMismatch,
4638 TosaErrorValidator.evInputSizeStartLengthMismatch,
4639 TosaErrorValidator.evWrongRank,
4640 TosaErrorValidator.evWrongInputType,
4641 TosaErrorValidator.evWrongOutputType,
4642 TosaErrorValidator.evWrongInputList,
4643 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004644 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004645 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004646 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004647 },
4648 "tile": {
4649 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004650 "operands": (2, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004651 "rank": (1, gtu.MAX_TENSOR_RANK),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004652 "build_fcn": (
4653 build_tile,
4654 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004655 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004656 TosaArgGen.agTile,
4657 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004658 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004659 "error_if_validators": (
4660 TosaErrorValidator.evWrongInputType,
4661 TosaErrorValidator.evWrongOutputType,
4662 TosaErrorValidator.evWrongInputList,
4663 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004664 TosaErrorValidator.evRankMismatch,
4665 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004666 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004667 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004668 },
4669 "transpose": {
4670 "op": Op.TRANSPOSE,
4671 "operands": (1, 0),
Jeremy Johnson18a379d2024-03-28 15:53:21 +00004672 "rank": (1, gtu.MAX_TENSOR_RANK),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004673 "build_fcn": (
4674 build_transpose,
4675 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004676 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004677 TosaArgGen.agTranspose,
4678 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004679 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004680 "error_if_validators": (
4681 TosaErrorValidator.evIndexOutsideBounds,
4682 TosaErrorValidator.evIndexUsedTwice,
4683 TosaErrorValidator.evWrongInputType,
4684 TosaErrorValidator.evWrongOutputType,
4685 TosaErrorValidator.evWrongInputList,
4686 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004687 TosaErrorValidator.evWrongRank,
4688 TosaErrorValidator.evRankMismatch,
4689 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004690 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004691 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004692 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004693 # Data nodes
4694 "const": {
4695 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004696 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004697 "build_fcn": (
4698 build_const,
4699 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004700 TosaTensorValuesGen.tvgLazyGenDefault,
4701 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004702 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004703 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha01ad8e1e22024-03-19 12:42:17 +00004704 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004705 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004706 "identity": {
4707 "op": Op.IDENTITY,
4708 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004709 "build_fcn": (
4710 build_unary,
4711 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004712 TosaTensorValuesGen.tvgLazyGenDefault,
4713 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004714 ),
evacha011adff832024-03-06 17:33:44 +00004715 "types": TYPE_FIB + [DType.INT4, DType.INT48],
evacha01ad8e1e22024-03-19 12:42:17 +00004716 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004717 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004718 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004719 "gather": {
4720 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004721 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004722 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004723 "build_fcn": (
4724 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004725 TosaTensorGen.tgGather,
4726 TosaTensorValuesGen.tvgGather,
4727 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004728 ),
James Ward24dbc422022-10-19 12:20:31 +01004729 "types": (
4730 DType.INT8,
4731 DType.INT16,
4732 DType.INT32,
4733 DType.FP16,
4734 DType.BF16,
4735 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004736 DType.FP8E4M3,
4737 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004738 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004739 "error_if_validators": (
4740 TosaErrorValidator.evWrongInputType,
4741 TosaErrorValidator.evWrongOutputType,
4742 TosaErrorValidator.evWrongInputList,
4743 TosaErrorValidator.evWrongOutputList,
4744 TosaErrorValidator.evWrongRank,
4745 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004746 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004747 },
4748 "scatter": {
4749 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004750 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004751 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004752 "build_fcn": (
4753 build_scatter,
4754 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004755 TosaTensorValuesGen.tvgScatter,
4756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004757 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004758 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004759 "error_if_validators": (
4760 TosaErrorValidator.evWrongInputType,
4761 TosaErrorValidator.evWrongOutputType,
4762 TosaErrorValidator.evWrongInputList,
4763 TosaErrorValidator.evWrongOutputList,
4764 TosaErrorValidator.evWrongRank,
4765 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004766 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004767 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004768 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004769 "resize": {
4770 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004771 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004772 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004773 "build_fcn": (
4774 build_resize,
4775 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004776 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004777 TosaArgGen.agResize,
4778 ),
James Ward24dbc422022-10-19 12:20:31 +01004779 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004780 "invalid_test_validators": (
4781 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004782 ),
4783 "error_if_validators": (
4784 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004785 TosaErrorValidator.evScaleSmallerEqualZero,
4786 TosaErrorValidator.evScaleNLargerMax,
4787 TosaErrorValidator.evScaleDLargerMax,
4788 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004789 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004790 TosaErrorValidator.evBorderSmallerMin,
4791 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004792 TosaErrorValidator.evWrongInputType,
4793 TosaErrorValidator.evWrongOutputType,
4794 TosaErrorValidator.evWrongRank,
4795 TosaErrorValidator.evWrongInputList,
4796 TosaErrorValidator.evWrongOutputList,
4797 TosaErrorValidator.evBatchMismatch,
4798 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004799 TosaErrorValidator.evResizeOutputShapeMismatch,
4800 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004801 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004802 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004803 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004804 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004805 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004806 "cast": {
4807 "op": Op.CAST,
4808 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004809 "build_fcn": (
4810 build_cast,
4811 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004812 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004813 TosaArgGen.agCast,
4814 ),
James Ward8b390432022-08-12 20:48:56 +01004815 "types": (
4816 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004817 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004818 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004819 DType.INT8,
4820 DType.INT16,
4821 DType.INT32,
4822 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004823 DType.FP8E4M3,
4824 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004825 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004826 "error_if_validators": (
4827 TosaErrorValidator.evWrongInputType,
4828 TosaErrorValidator.evWrongOutputType,
4829 TosaErrorValidator.evWrongInputList,
4830 TosaErrorValidator.evWrongOutputList,
4831 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004832 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson708da822023-11-15 16:25:45 +00004833 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004834 },
4835 "rescale": {
4836 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004837 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004838 "build_fcn": (
4839 build_rescale,
4840 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004841 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004842 TosaArgGen.agRescale,
4843 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004844 "types": [
4845 DType.UINT8,
4846 DType.INT8,
4847 DType.INT16,
4848 DType.INT32,
4849 DType.INT48,
4850 DType.UINT16,
4851 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004852 "error_if_validators": (
4853 TosaErrorValidator.evInputZeroPointNotZero,
4854 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004855 TosaErrorValidator.evU16InputZeroPointNotValid,
4856 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004857 TosaErrorValidator.evScaleTrue,
4858 TosaErrorValidator.evScaleNotTrue,
4859 TosaErrorValidator.evWrongInputType,
4860 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004861 TosaErrorValidator.evWrongInputList,
4862 TosaErrorValidator.evWrongOutputList,
4863 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004864 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004865 # Custom
4866 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004867 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004868 # Two varients of cond_if, one that generates one of two constant tensors (no
4869 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4870 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004871 "cond_if_const": {
4872 "op": Op.COND_IF,
4873 "operands": (0, 2),
4874 "build_fcn": (
4875 build_cond_if_const,
4876 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004877 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004878 TosaArgGen.agCondIf,
4879 ),
4880 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004881 "error_if_validators": (
4882 TosaErrorValidator.evOutputListThenGraphMismatch,
4883 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004884 TosaErrorValidator.evCondIfCondNotMatchingBool,
4885 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004886 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004887 },
4888 "cond_if_binary": {
4889 "op": Op.COND_IF,
4890 "operands": (2, 0),
4891 "build_fcn": (
4892 build_cond_if_binary,
4893 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004894 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004895 TosaArgGen.agCondIf,
4896 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004897 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004898 "error_if_validators": (
4899 TosaErrorValidator.evInputListThenGraphMismatch,
4900 TosaErrorValidator.evInputListElseGraphMismatch,
4901 TosaErrorValidator.evOutputListThenGraphMismatch,
4902 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004903 TosaErrorValidator.evCondIfCondNotMatchingBool,
4904 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004905 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004906 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004907 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004908 "while_loop": {
4909 "op": Op.WHILE_LOOP,
4910 "operands": (0, 1),
4911 "build_fcn": (
4912 build_while_loop,
4913 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004914 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004915 TosaArgGen.agWhileLoop,
4916 ),
4917 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004918 "error_if_validators": (
4919 TosaErrorValidator.evInputListOutputListMismatch,
4920 TosaErrorValidator.evInputListCondGraphMismatch,
4921 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4922 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4923 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004924 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004925 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004926 },
Luke Hutton57287132023-02-06 14:54:18 +00004927 "fft2d": {
4928 "op": Op.FFT2D,
4929 "operands": (2, 0),
4930 "rank": (3, 3),
4931 "build_fcn": (
4932 build_fft2d,
4933 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004934 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004935 TosaArgGen.agFFT2d,
4936 ),
4937 "types": [DType.FP32],
4938 "error_if_validators": (
4939 TosaErrorValidator.evWrongInputType,
4940 TosaErrorValidator.evWrongOutputType,
4941 TosaErrorValidator.evWrongInputList,
4942 TosaErrorValidator.evWrongOutputList,
4943 TosaErrorValidator.evWrongRank,
4944 TosaErrorValidator.evBatchMismatch,
4945 TosaErrorValidator.evKernelNotPowerOfTwo,
4946 TosaErrorValidator.evFFTInputShapeMismatch,
4947 TosaErrorValidator.evFFTOutputShapeMismatch,
4948 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004949 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton57287132023-02-06 14:54:18 +00004950 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004951 "rfft2d": {
4952 "op": Op.RFFT2D,
4953 "operands": (1, 0),
4954 "rank": (3, 3),
4955 "build_fcn": (
4956 build_rfft2d,
4957 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004958 TosaTensorValuesGen.tvgLazyGenDefault,
4959 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004960 ),
4961 "types": [DType.FP32],
4962 "error_if_validators": (
4963 TosaErrorValidator.evWrongInputType,
4964 TosaErrorValidator.evWrongOutputType,
4965 TosaErrorValidator.evWrongInputList,
4966 TosaErrorValidator.evWrongOutputList,
4967 TosaErrorValidator.evWrongRank,
4968 TosaErrorValidator.evBatchMismatch,
4969 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004970 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004971 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004972 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton261b7b62023-01-10 14:50:31 +00004973 },
Won Jeon74342e52024-01-09 00:34:40 +00004974 # Shape
4975 "add_shape": {
4976 "op": Op.ADD_SHAPE,
4977 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004978 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004979 "build_fcn": (
4980 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004981 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004982 TosaTensorValuesGen.tvgAddSub,
4983 TosaArgGen.agNone,
4984 ),
4985 "types": [DType.SHAPE],
4986 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4987 },
4988 "sub_shape": {
4989 "op": Op.SUB_SHAPE,
4990 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004991 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004992 "build_fcn": (
4993 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004994 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004995 TosaTensorValuesGen.tvgAddSub,
4996 TosaArgGen.agNone,
4997 ),
4998 "types": [DType.SHAPE],
4999 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5000 },
5001 "mul_shape": {
5002 "op": Op.MUL_SHAPE,
5003 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005004 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005005 "build_fcn": (
5006 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005007 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005008 TosaTensorValuesGen.tvgMul,
5009 TosaArgGen.agNone,
5010 ),
5011 "types": [DType.SHAPE],
5012 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5013 },
5014 "div_shape": {
5015 "op": Op.DIV_SHAPE,
5016 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005017 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005018 "build_fcn": (
5019 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005020 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005021 TosaTensorValuesGen.tvgIntDiv,
5022 TosaArgGen.agNone,
5023 ),
5024 "types": [DType.SHAPE],
5025 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5026 },
5027 "concat_shape": {
5028 "op": Op.CONCAT_SHAPE,
5029 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005030 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005031 "build_fcn": (
5032 build_concat,
5033 TosaTensorGen.tgConcat,
5034 TosaTensorValuesGen.tvgConcat,
5035 TosaArgGen.agNone,
5036 ),
5037 "types": [DType.SHAPE],
5038 "error_if_validators": (),
5039 },
5040 "const_shape": {
5041 "op": Op.CONST_SHAPE,
5042 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005043 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005044 "build_fcn": (
5045 build_const,
5046 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005047 TosaTensorValuesGen.tvgLazyGenDefault,
5048 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005049 ),
5050 "types": [DType.SHAPE],
5051 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005052 }
5053
Kevin Cheng550ccc52021-03-03 11:21:43 -08005054
Eric Kunzee5e26762020-10-13 16:11:07 -07005055class OutputShaper:
5056 # Methods in this class compute the expected output shape and datatype
5057 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper; there is no per-instance state to set up."""
5060
5061 # These methods return arguments that can be used for
5062 # creating a new output tensor
5063 @staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a broadcasting elementwise binary op.

    ``a`` and ``b`` must share a dtype and, unless a RankMismatch ERROR_IF
    test is being built, a rank. For ordinary tests (``error_name`` is None)
    every size-1 dimension of ``a`` takes the matching size from ``b``;
    otherwise ``a``'s dimensions are used unchanged.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only consulted for ERROR_IF variants.
        a, b: input tensors (only ``shape`` and ``dtype`` are read).
        error_name: the ``ErrorIf`` case to provoke, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    shape = [
        b.shape[idx] if dim == 1 and broadcast else dim
        for idx, dim in enumerate(a.shape)
    ]

    # DimensionMismatch: grow one randomly chosen dimension (needs rank > 0).
    if shape and error_name == ErrorIf.DimensionMismatch:
        pos = rng.integers(0, len(shape))
        shape[pos] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    # WrongOutputType: pick any listed dtype other than the input's.
    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005098
5099 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output tensor for an elementwise binary op without broadcasting.

    Both inputs must have the same rank, dtype and dimensions; the output
    takes a fresh copy of ``a``'s shape and ``a``'s dtype.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    for dim_a, dim_b in zip(a.shape, b.shape):
        assert dim_a == dim_b
    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005110
5111 @staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for an elementwise unary op.

    The output always keeps ``a``'s shape. Its dtype is ``a.dtype`` unless a
    WrongOutputType ERROR_IF test is requested, in which case a different
    dtype from a fixed list is drawn with ``rng``.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(a.shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(a.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005129
5130 @staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for SELECT(cond, a, b).

    ``cond``, ``a`` and ``b`` must have equal ranks (except for RankMismatch
    ERROR_IF tests) and ``a``/``b`` must share a dtype. For ordinary tests a
    size-1 dimension of ``cond`` becomes the largest of the three inputs'
    sizes in that dimension; other dimensions follow ``cond``.

    Args:
        ser: serializer whose ``addOutput(shape, dtype)`` creates the tensor.
        rng: random generator, only consulted for ERROR_IF variants.
        cond, a, b: input tensors (only ``shape``/``dtype`` are read).
        error_name: the ``ErrorIf`` case to provoke, or None.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    shape = [
        max(dim, a.shape[idx], b.shape[idx]) if dim == 1 and broadcast else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # DimensionMismatch: grow one randomly chosen dimension (needs rank > 0).
    if shape and error_name == ErrorIf.DimensionMismatch:
        pos = rng.integers(0, len(shape))
        shape[pos] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005165
5166 @staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for an elementwise comparison op.

    ``a`` and ``b`` must share a dtype and (except for RankMismatch ERROR_IF
    tests) a rank. A size-1 dimension of ``a`` takes ``b``'s size whenever
    ``b`` has that dimension. The output dtype is BOOL unless a
    WrongOutputType test is requested.

    Returns:
        The result of ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # The length guard matters when RankMismatch lets ``b`` have fewer dims.
    shape = [
        b.shape[idx] if dim == 1 and len(b.shape) > idx else dim
        for idx, dim in enumerate(a.shape)
    ]

    # DimensionMismatch: grow one randomly chosen dimension (needs rank > 0).
    if shape and error_name == ErrorIf.DimensionMismatch:
        pos = rng.integers(0, len(shape))
        shape[pos] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.BOOL)

    # None of these is BOOL, so any choice is a wrong output type.
    non_bool_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(shape, rng.choice(non_bool_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005200
5201 @staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    Normally the reduced axis collapses to size 1. For the AxisSmallerZero,
    AxisLargerRank and ShapeOfAxisNotOne ERROR_IF cases the input shape is
    kept. For ShapeOfAxisNotOne, an axis that is already 1 is replaced with
    ``rng.integers(2, 10)``.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    keeps_axis = error_name in (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if not keeps_axis:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005229
5230 @staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the INT32 index output tensor for ARGMAX along ``axis``.

    The reduced axis is removed from ``a``'s shape, except for the
    AxisSmallerZero/AxisLargerRank ERROR_IF cases. The rank- and
    shape-mismatch ERROR_IF cases then perturb the result:
    - ArgmaxOutputRankMismatch drops the leading dim or appends a 1.
    - ArgmaxOutputShapeMismatch adds a random amount to every dim.

    Returns:
        The result of ``ser.addOutput``.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Only drop a dim when at least one would remain afterwards.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))
    else:
        out_dtype = DType.INT32

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005265
5266 @staticmethod
def _get_conv_output_type(input_dtype):
    """Return the output dtype a convolution produces for ``input_dtype``.

    FP16/BF16/FP32 inputs keep their own type, FP8E4M3/FP8E5M2 produce FP16,
    INT8/INT4 produce INT32 and INT16 produces INT48.

    Raises:
        AssertionError: if ``input_dtype`` is not a supported convolution
            input type.
    """
    if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
        return input_dtype
    if input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        return DType.FP16
    if input_dtype in (DType.INT8, DType.INT4):
        return DType.INT32
    if input_dtype in (DType.INT16,):
        return DType.INT48
    # Fail loudly instead of silently returning None (which would then be
    # used as an output dtype). An explicit raise is not stripped by -O.
    raise AssertionError(f"Unsupported convolution data type {input_dtype}")
5277
5278 @staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the CONV2D output tensor to ``ser`` and return it.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding[0:2]`` pads
    H and ``padding[2:4]`` pads W; ``strides``/``dilations`` are indexed
    [H, W]. ``accum_dtype`` is not used here.
    """
    batch, in_h, in_w = ifm.shape[0], ifm.shape[1], ifm.shape[2]
    out_channels, kernel_h, kernel_w = filter.shape[0], filter.shape[1], filter.shape[2]

    out_h = (
        in_h - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    out_w = (
        in_w - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by whole multiples of the stride so only the shape-mismatch
        # check is hit, not the non-integer output case.
        choices = [1, 2, 3]
        change = rng.choice(choices)  # 1: H only, 2: W only, 3: both
        if change in [1, 3]:
            out_h += rng.choice(choices) * strides[0]
        if change in [2, 3]:
            out_w += rng.choice(choices) * strides[1]

    ofm_shape = [batch, out_h, out_w, out_channels]

    # With a wrong input type any plausible output dtype will do.
    out_dtype = (
        DType.INT32
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        # Dtypes that must not be offered as "wrong"; FP16 input also excludes
        # FP32 (presumably another valid output for FP16).
        special_excludes = {
            DType.FP16: [DType.FP16, DType.FP32],
            DType.FP8E4M3: [DType.FP16],
            DType.FP8E5M2: [DType.FP16],
        }
        excludes = special_excludes.get(ifm.dtype, [out_dtype])
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005331
5332 @staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the CONV3D output tensor to ``ser`` and return it.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. Spatial dim ``i``
    (0=D, 1=H, 2=W) uses ``padding[2*i]`` + ``padding[2*i+1]``,
    ``dilations[i]`` and ``strides[i]``. ``accum_dtype`` is not used here.
    """
    out_dims = [
        (
            ifm.shape[i + 1]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i + 1] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        # change 1/2/3 perturbs only D/H/W respectively; 4 perturbs all three
        change = rng.choice(choices)
        for i in range(3):
            if change in [i + 1, 4]:
                # Multiples of the stride avoid the non-integer output case.
                out_dims[i] += rng.choice(choices) * strides[i]

    ofm_shape = [ifm.shape[0], *out_dims, filter.shape[0]]

    # With a wrong input type any plausible output dtype will do.
    out_dtype = (
        DType.INT32
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5393
5394 @staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the DEPTHWISE_CONV2D output tensor to ``ser`` and return it.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). ``padding[0:2]``
    pads H and ``padding[2:4]`` pads W; ``strides``/``dilations`` are indexed
    [H, W]. ``accum_dtype`` is not used here.
    """
    kernel_h, kernel_w = filter.shape[0], filter.shape[1]
    channel_mult = filter.shape[2] * filter.shape[3]

    out_h = (
        ifm.shape[1] - 1 + padding[0] + padding[1] - (kernel_h - 1) * dilations[0]
    ) // strides[0] + 1
    out_w = (
        ifm.shape[2] - 1 + padding[2] + padding[3] - (kernel_w - 1) * dilations[1]
    ) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        # Grow by whole multiples of the stride so only the shape-mismatch
        # check is hit, not the non-integer output case.
        choices = [1, 2, 3]
        change = rng.choice(choices)  # 1: H only, 2: W only, 3: both
        if change in [1, 3]:
            out_h += rng.choice(choices) * strides[0]
        if change in [2, 3]:
            out_w += rng.choice(choices) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, channel_mult]

    # With a wrong input type any plausible output dtype will do.
    out_dtype = (
        DType.INT32
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005444
5445 @staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the 2D pooling output tensor (NHWC) to ``ser`` and return it.

    ``pad[0:2]`` pads H and ``pad[2:4]`` pads W; ``kernel`` and ``stride`` are
    indexed [H, W]. The output dtype matches ``ifm`` unless WrongOutputType
    asks for a different one.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # The test is invalid anyway, so just use a 1x1 spatial output.
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        # Grow by whole multiples of the stride so only the shape-mismatch
        # check is hit, not the non-integer output case.
        choices = [1, 2, 3]
        change = rng.choice(choices)  # 1: H only, 2: W only, 3: both
        if change in [1, 3]:
            out_h += rng.choice(choices) * stride[0]
        if change in [2, 3]:
            out_w += rng.choice(choices) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    out_dtype = ifm.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005484
5485 @staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the FULLY_CONNECTED output tensor to ``ser`` and return it.

    Shapes: input is [N, IC], filter is [OC, IC], output is [N, OC].
    The output dtype is ``accum_dtype`` unchanged. It is validated in arg_gen
    (and also invalidated there for ERROR_IF), so ``error_name`` and ``rng``
    are not consulted here.
    """
    batch_size = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch_size, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005497
5498 @staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the MATMUL output tensor to ``ser`` and return it.

    Shapes: a is [N, H, C], b is [N, C, W], output is [N, H, W]. Normally the
    output dtype is ``accum_dtype`` (validated in arg_gen). WrongInputType
    uses INT32, and WrongOutputType picks a dtype that is incorrect for
    ``a.dtype``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        # For a wrong input type, pick some potentially correct output dtype.
        out_dtype = (
            DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype
        )
        return ser.addOutput(output_shape, out_dtype)

    small_ints = (DType.INT4, DType.INT8, DType.INT16)
    fp8_types = (DType.FP8E4M3, DType.FP8E5M2)
    if a.dtype == DType.INT8:
        incorrect_types = (
            small_ints + (DType.INT48, DType.FP32, DType.FP16, DType.BF16) + fp8_types
        )
    elif a.dtype == DType.INT16:
        incorrect_types = (
            small_ints + (DType.INT32, DType.FP32, DType.FP16, DType.BF16) + fp8_types
        )
    elif a.dtype in fp8_types:
        incorrect_types = (
            small_ints + (DType.INT32, DType.INT48, DType.FP32, DType.BF16) + fp8_types
        )
    elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
        incorrect_types = small_ints + (DType.INT32, DType.INT48) + fp8_types

    return ser.addOutput(output_shape, rng.choice(a=incorrect_types))
Eric Kunzee5e26762020-10-13 16:11:07 -07005563
5564 @staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the CONCAT output tensor to ``ser`` and return it.

    The output has the shape of ``inputs[0]`` with the ``axis`` dimension
    summed over all inputs, and the dtype of ``inputs[0]``. If the error
    case makes the sum impossible (rank mismatch or invalid axis), the first
    input's shape is used unchanged.
    """
    first = inputs[0]
    rest = inputs[1:]

    output_shape = first.shape.copy()
    sum_is_possible = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if sum_is_possible:
        for tensor in rest:
            output_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    out_dtype = first.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005599
5600 @staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the PAD output tensor to ``ser`` and return it.

    Each output dim is ``padding[i][0] + padding[i][1] + a.shape[i]``; the
    dtype matches ``a``. ERROR_IF cases may bump one dim, change the rank, or
    pick a different dtype.
    """
    output_shape = a.shape.copy()
    for i, dim in enumerate(output_shape):
        output_shape[i] = padding[i][0] + padding[i][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005630
5631 @staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the DIM output tensor to ``ser`` and return it.

    The output always has shape [1] and dtype SHAPE; ``a`` and ``axis`` do
    not affect it. WrongOutputType picks a random dtype from a list that does
    not contain SHAPE.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5651
5652 @staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the RESHAPE output tensor to ``ser`` and return it.

    The output uses a copy of ``shape`` and the dtype of ``a``.
    TensorSizeInputOutputMismatch inflates every dim, and WrongOutputType
    picks a different dtype.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dim so the element count no longer matches the input.
        for i, dim in enumerate(output_shape):
            output_shape[i] = dim + rng.integers(1, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005676
5677 @staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the SLICE output tensor to ``ser`` and return it.

    The output shape is a copy of ``size`` and the dtype matches ``input``.
    ERROR_IF cases perturb the dims, reuse the input shape, change the rank,
    or pick a different dtype. ``start`` is not used here.
    """
    out_dtype = input.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for i, dim in enumerate(output_shape):
            # Dims <= 2 are only grown (presumably to keep them positive).
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[i] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005710
5711 @staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the TILE output tensor to ``ser`` and return it.

    Each output dim is ``a.shape[i] * multiples[i]``, and the dtype matches
    ``a``. RankMismatch changes the rank; WrongOutputType picks a different
    dtype.
    """
    assert len(multiples) == len(a.shape)
    output_shape = [dim * mult for dim, mult in zip(a.shape, multiples)]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005739
5740 @staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the TRANSPOSE output tensor to ``ser`` and return it.

    Output dim ``i`` is ``a.shape[perms[i]]``, and the dtype matches ``a``.
    When ``perms`` is deliberately invalid (out of bounds or repeated), the
    input shape is kept unpermuted.
    """
    assert len(perms) == len(a.shape)

    perms_are_valid = error_name not in [
        ErrorIf.IndexOutsideBounds,
        ErrorIf.IndexUsedTwice,
    ]
    if perms_are_valid:
        output_shape = [a.shape[p] for p in perms]
    else:
        output_shape = a.shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for i, dim in enumerate(output_shape):
            output_shape[i] = dim + rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005772
5773 @staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the GATHER output tensor to ``ser`` and return it.

    ``values`` must be rank 3 (unless testing WrongRank) and ``indices`` rank
    2, with matching dim 0. The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]] and the dtype
    matches ``values``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch, _, channels = values.shape[0], None, values.shape[2]
    output_shape = [batch, indices.shape[1], channels]

    out_dtype = values.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))

    return ser.addOutput(output_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005798
5799 @staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the SCATTER output tensor to ``ser`` and return it.

    ``values_in`` must be rank 3 (unless testing WrongRank); ``indices`` is
    rank 2 and ``input`` rank 3, with N, W and C dims agreeing as asserted
    below. The output has the shape and dtype of ``values_in``.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - {values_in.dtype})
    return ser.addOutput(values_in.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005829
5830 @staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the TABLE output tensor to ``ser`` and return it.

    The output has the same shape as ``input``. Its dtype is INT32 for INT16
    input and INT8 otherwise. WrongOutputType picks any other dtype from a
    fixed list.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    output_dtype = DType.INT8
    if input.dtype == DType.INT16:
        output_dtype = DType.INT32

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = [dt for dt in candidate_dtypes if dt != output_dtype]
        output_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005849
5850 @staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the RESIZE output tensor (NHWC) to ``serializer`` and return it.

    ``scale`` is read as [y_n, y_d, x_n, x_d]; ``offset`` and ``border`` as
    [y, x]. ``mode`` and ``input_dtype`` are not used here, and the output
    dtype is always ``output_dtype``. For ERROR_IF tests the computed size is
    first kept >= 1 (and, except for MaxDimExceeded, below
    MAX_RESIZE_DIMENSION) before any deliberate corruption of the shape.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp so the size arithmetic below stays meaningful.
        y_num, y_den, x_num, x_den = (
            max(v, 1) for v in (y_num, y_den, x_num, x_den)
        )

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # scale, offset or border may have been altered for the ERROR_IF case,
        # so keep the output tensor itself valid.
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            out_h = min(out_h, gtu.MAX_RESIZE_DIMENSION - 1)
            out_w = min(out_w, gtu.MAX_RESIZE_DIMENSION - 1)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def _nudge(size, step):
            # Move by one scale denominator so the result stays integral;
            # step down instead if stepping up would reach the max dimension.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        choices = [1, 2, 3]
        change = rng.choice(choices)  # 1: H only, 2: W only, 3: both
        if change in [1, 3]:
            out_h = _nudge(out_h, y_den)
        if change in [2, 3]:
            out_w = _nudge(out_w, x_den)

    if error_name == ErrorIf.WrongRank:
        # The last dim reuses shape[0] rather than indexing shape[3]
        # (presumably because the input rank may not be 4 here).
        output_dims = [input.shape[0], out_h, out_w, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [
            input.shape[0] + rng.integers(1, 10),
            out_h,
            out_w,
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [
            input.shape[0],
            out_h,
            out_w,
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        output_dims = [input.shape[0], out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005928
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted so the signature matches the
    other output-shaper helpers; neither influences the result here.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005932
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor for a transpose-conv2d operator.

    Args:
        ser: serializer whose ``addOutput`` creates the tensor.
        rng: random generator, used only when injecting ERROR_IF cases.
        ifm: input feature-map tensor; its dtype selects the output dtype.
        output_shape: requested output shape. NOTE: for
            ``ErrorIf.ConvOutputShapeMismatch`` this object is modified IN
            PLACE (index 1 and/or 2 are enlarged), so the caller sees the
            fuzzed shape too.
        accum_dtype: accumulator dtype; not read by this helper.
        error_name: optional ``ErrorIf`` condition being generated.

    Returns:
        The result of ``ser.addOutput(output_shape, out_dtype)``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        deltas = [1, 2, 3]
        # 1 -> grow index 1 only, 2 -> grow index 2 only, 3 -> grow both
        which = rng.choice(deltas)
        for dim, selectors in ((1, (1, 3)), (2, (2, 3))):
            if which in selectors:
                output_shape[dim] += rng.choice(deltas)

    # With an invalid input dtype, fall back to a plausible output dtype
    out_dtype = (
        DType.INT32
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 input both FP16 and FP32 are excluded (presumably both are
        # treated as acceptable outputs for FP16 input); otherwise only the
        # expected dtype is excluded.
        banned = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=banned)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005958
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors for an FFT2D operator.

    Args:
        serializer: object whose ``addOutput`` creates each tensor.
        rng: random generator, used only when injecting ERROR_IF cases.
        ifm1, ifm2: the two input tensors (presumably real and imaginary
            parts -- confirm against the op builder). They must share a
            dtype, and a shape unless ``FFTInputShapeMismatch`` is requested.
        error_name: optional ``ErrorIf`` condition being generated.

    Returns:
        A two-element list of ``serializer.addOutput`` results, both created
        with the same (possibly fuzzed) shape and dtype.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Outputs mirror the input unless an error case perturbs them
    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5989
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors for an RFFT2D operator.

    The last dimension of the output is ``W // 2 + 1`` where ``W`` is the
    input's last dimension (presumably the non-redundant half of a
    real-input FFT); leading dimensions are copied from the input.

    Args:
        serializer: object whose ``addOutput`` creates each tensor.
        rng: random generator, used only when injecting ERROR_IF cases.
        value: the input tensor; rank 3 unless ``WrongRank`` is requested.
        error_name: optional ``ErrorIf`` condition being generated.

    Returns:
        A two-element list of ``serializer.addOutput`` results, both created
        with the same (possibly fuzzed) shape and dtype.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        dim = rng.choice([1, 2])
        out_shape[dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00006014
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for the add-shape operation.

    The output copies ``a``'s shape and normally has dtype ``DType.SHAPE``.

    Args:
        ser: serializer whose ``addOutput`` creates the tensor.
        rng: random generator, used only when injecting ERROR_IF cases.
        a, b: the two operands; they must share a dtype, and a rank unless
            ``RankMismatch`` is requested.
        error_name: optional ``ErrorIf`` condition being generated.

    Returns:
        The result of ``ser.addOutput`` for the new tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = a.shape.copy()

    # Rank-0 operands are not expected here, and the DimensionMismatch
    # fuzzing below needs at least one dimension to perturb.
    assert len(out_shape) > 0
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[rng.integers(0, len(out_shape))] += 1

    out_dtype = (
        rng.choice(gtu.get_wrong_output_type(a.dtype))
        if error_name == ErrorIf.WrongOutputType
        else DType.SHAPE
    )
    return ser.addOutput(out_shape, out_dtype)