blob: c8670705b78f8f0c055d0618a05ec7b544ad7d7d [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010023from generator.tosa_random_gen import TosaHashRandomGenerator
24from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010025from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000026from tosa.DType import DType
27from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010028
# Header written at the top of the generated C++ meta data files. The
# copyright year is taken from the date this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# Module-wide logger shared by the test generator
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
36
Matthew Haddonb724efc2021-08-25 16:40:29 +010037
Eric Kunzee5e26762020-10-13 16:11:07 -070038class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010039 # Maximum rank of tensor supported by test generator.
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000040 # This currently matches the 8K level defined in the specification.
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010041 TOSA_TENSOR_MAX_RANK = 6
Jeremy Johnsonb2099702023-04-12 15:59:01 +010042 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010043 TOSA_8K_LEVEL_MAX_KERNEL = 8192
44 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010045
Jeremy Johnson1271c442023-09-05 11:39:26 +010046 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000047 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010048 TOSA_MI_DOT_PRODUCT_MIN = 1000
49
Eric Kunzee5e26762020-10-13 16:11:07 -070050 def __init__(self, args):
51 self.args = args
52 self.basePath = args.output_dir
53 self.random_seed = args.random_seed
54 self.ser = None
Eric Kunzee5e26762020-10-13 16:11:07 -070055 self.createDynamicOpLists()
56 self.initOpListDefaults()
57 self.quantGen = TosaQuantGen()
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010058 self.global_rng = None
Eric Kunzee5e26762020-10-13 16:11:07 -070059 # Force makeShape to do a specific starting shape
60 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010061 # JSON schema validation
62 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010063 # Data generator library is sometimes needed for compliance set up
64 # even if we are generating the data later (lazy_data_generation)
65 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070066
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010067 # Work out floating point range
68 def convertFPRange(rangeFP, maxFP):
69 # Converts program arguments of max/-max to FP max
70 vals = []
71 for v in rangeFP:
72 if v == "max":
73 v = maxFP
74 elif v == "-max":
75 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000076 elif v < 0:
77 # Trim to minimum data type value
78 v = max(v, -maxFP)
79 elif v > 0:
80 # Trim to maximum data type value
81 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010082 vals.append(v)
83 return tuple(sorted(vals))
84
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010085 self.random_dtype_range = {
86 DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
87 }
Won Jeon2c34b462024-02-06 18:37:00 +000088 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010089 self.random_dtype_range[dtype] = convertFPRange(
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010090 args.tensor_fp_value_range,
91 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
92 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010093 self.resetGlobalRNG()
94
95 def resetGlobalRNG(self):
96 self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010097
Eric Kunzee5e26762020-10-13 16:11:07 -070098 def createSerializer(self, opName, testPath):
99 self.testPath = os.path.join(opName, testPath)
100
101 fullPath = os.path.join(self.basePath, self.testPath)
102 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100103 # Embed const data in the flatbuffer
104 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +0100105 if self.args.lazy_data_gen:
106 # Lazy data generation - so make constants files
107 constMode = ts.ConstMode.INPUTS
108 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100109 constMode = ts.ConstMode.EMBED_DUMP
110 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700111
112 def getSerializer(self):
113 return self.ser
114
Jeremy Johnson1271c442023-09-05 11:39:26 +0100115 def serialize(self, testName, metaData=None):
116 path = Path(self.basePath) / self.testPath
117
118 # Write out TOSA flatbuffer binary
119 path_fb = path / f"{testName}.tosa"
120 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700121 fd.write(self.ser.serialize())
122
Jeremy Johnson1271c442023-09-05 11:39:26 +0100123 # Get JSON descriptor from serializer
124 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
125
126 if metaData:
127 # Add extra meta data to desc.json
128 desc["meta"] = metaData
129
130 # Validate desc.json before we output it
131 self.descSchemaValidator.validate_config(desc)
132
133 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100134 if "data_gen" in metaData:
135 if self.args.lazy_data_gen:
136 # Output datagen meta data as CPP data
137 path_md = path / f"{testName}_meta_data_gen.cpp"
138 with path_md.open("w") as fd:
139 fd.write(TOSA_AUTOGENERATED_HEADER)
140 fd.write("// Test meta data for data generation setup\n\n")
141 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
142 json.dump(metaData["data_gen"], fd)
143 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100144 if "compliance" in metaData:
145 # Output datagen meta data as CPP data
146 path_md = path / f"{testName}_meta_compliance.cpp"
147 with path_md.open("w") as fd:
148 fd.write(TOSA_AUTOGENERATED_HEADER)
149 fd.write("// Test meta data for compliance validation\n\n")
150 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
151 json.dump(metaData["compliance"], fd)
152 fd.write(')";\n\n')
153
154 # Write desc.json
155 path_desc = path / "desc.json"
156 with path_desc.open("w") as fd:
157 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100159 def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 placeholders = []
161
Kevin Cheng989cb052021-04-28 16:29:44 -0700162 assert len(shape_list) == len(dtype_list)
163
Jeremy Johnson1271c442023-09-05 11:39:26 +0100164 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700165 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100166 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100167 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700168 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700169
170 return placeholders
171
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100172 def buildConstTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 consts = []
174
Kevin Cheng989cb052021-04-28 16:29:44 -0700175 assert len(shape_list) == len(dtype_list)
176
Jeremy Johnson1271c442023-09-05 11:39:26 +0100177 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700178 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100179 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100180 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700181 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700182
183 return consts
184
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100185 def makeShape(self, rng, rank):
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 if self.targetted_shape:
187 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800188 return np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100189 rng.integers(
Kevin Cheng550ccc52021-03-03 11:21:43 -0800190 low=self.args.tensor_shape_range[0],
191 high=self.args.tensor_shape_range[1],
192 size=rank,
193 )
194 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700195
196 def setTargetShape(self, shape):
197 self.targetted_shape = shape
198
Eric Kunzee5e26762020-10-13 16:11:07 -0700199 def shapeStr(self, shape):
200
201 sStr = []
202 # Convert to strings
203 for i in shape:
204 sStr.append(str(i))
205
Kevin Cheng550ccc52021-03-03 11:21:43 -0800206 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100208 def typeStr(self, dtype):
209 if isinstance(dtype, list) or isinstance(dtype, tuple):
210 assert len(dtype) >= 2
211 strs = [self.typeStr(t) for t in dtype]
212 # Limit types to the first 2 as the 3rd is the accumulator
213 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100215 if dtype in gtu.DTYPE_ATTRIBUTES:
216 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700217 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100218 raise Exception(
219 "Unknown dtype, cannot convert to string: {}".format(dtype)
220 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700221
Luke Hutton57287132023-02-06 14:54:18 +0000222 def constrictBatchSize(self, shape):
223 # Limit the batch size unless an explicit target shape set
224 if self.args.max_batch_size and not self.args.target_shapes:
225 shape[0] = min(shape[0], self.args.max_batch_size)
226 return shape
227
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100228 def makeDimension(self, rng):
229 return rng.randInt(
James Ward30124a82023-02-02 14:56:33 +0000230 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
231 )
232
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100233 def tensorComplianceMetaData(
234 self, op, inputType, argsDict, outputTensor, errorName
235 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000236 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
237 UNSUPPORTED_NON_FP32_INPUT_OPS = (
238 Op.MATMUL,
239 Op.CONV2D,
240 Op.FULLY_CONNECTED,
241 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000242 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000243 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000244 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100245 if (
246 errorName
247 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000248 or (
249 not gtu.dtypeIsSupportedByCompliance(inputType)
250 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
251 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100252 ):
253 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100254 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100255
Jeremy Johnson1271c442023-09-05 11:39:26 +0100256 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100257 compliance_tens = {
258 "mode": None,
259 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
260 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
261 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100262 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
263 mode = gtu.ComplianceMode.DOT_PRODUCT
264 compliance_tens["dot_product_info"] = {
265 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100266 "ks": int(argsDict["ksb"])
267 if "ksb" in argsDict
268 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100269 }
evacha019c96eef2024-02-07 11:21:55 +0000270 elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100271 mode = gtu.ComplianceMode.FP_SPECIAL
272 elif "compliance" in op and "ulp" in op["compliance"]:
273 mode = gtu.ComplianceMode.ULP
274 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000275 elif "compliance" in op and "relative" in op["compliance"]:
276 mode = gtu.ComplianceMode.RELATIVE
277 compliance_tens["relative_info"] = {
278 "max": argsDict["max_abs_value"],
279 "scale": op["compliance"]["relative"],
280 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100281 elif op["op"] == Op.REDUCE_PRODUCT:
282 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000283 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000284 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000285 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000286 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
287 compliance_tens["abs_error_info"] = {
288 "lower_bound": op["compliance"]["abs_error_lower_bound"]
289 }
Jerry Ge51bd4f52024-02-20 11:21:19 -0800290 elif op["op"] in (Op.SIN, Op.COS):
291 mode = gtu.ComplianceMode.ABS_ERROR
292 if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
293 compliance_tens["abs_error_info"] = {
294 "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
295 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100296 else:
297 mode = gtu.ComplianceMode.EXACT
298 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
299
300 return compliance_tens
301
302 # Build Op functions
303 # Create the output tensor (calling OutputShaper as needed)
304 # Do final tweaks to attributes (if necessary for errorIf)
305 # Add Op into graph
306 # Return resulting tensor information or BuildInfo
307
308 class BuildInfo:
309 """Enhanced build information containing result tensor and associated compliance dict."""
310
311 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000312 if isinstance(resultTensor, list):
313 assert complianceDict is None or isinstance(complianceDict, list)
314 self.resultTensorList = resultTensor
315 self.complianceDictList = complianceDict
316 else:
317 self.resultTensorList = [resultTensor]
318 if complianceDict is None:
319 self.complianceDictList = None
320 else:
321 self.complianceDictList = [complianceDict]
322
323 def getComplianceInfo(self):
324 if self.complianceDictList is None:
325 return None
326 else:
327 tens_dict = {}
328 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
329 if comp is not None:
330 tens_dict[tens.name] = comp
331
332 if tens_dict:
333 # Have some compliance data, so return the info
334 compliance = {
335 "version": "0.1",
336 "tensors": tens_dict,
337 }
338 else:
339 compliance = None
340 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700341
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000342 def build_unary(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100343 self,
344 rng,
345 op,
346 inputs,
347 args_dict,
348 validator_fcns=None,
349 error_name=None,
350 qinfo=None,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000351 ):
352 assert len(inputs) == 1
353 a = inputs[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100354 result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100355
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000356 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100357
358 # Ensure new output type has correct qinfo
359 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000360 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000361 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100362 TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype),
363 TosaQuantGen.getZeroPoint(
364 rng, self.args.zeropoint, result_tensor.dtype
365 ),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000366 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100367
368 # Invalidate Input/Output list for error if checks.
369 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000370 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100371 pCount, cCount = op["operands"]
372 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000373 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100374 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000375 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100376
Les Bell729b0352021-11-24 10:28:21 +0000377 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100378 self.ser,
379 validator_fcns,
380 error_name,
381 op=op,
382 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000383 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000385 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100386 input_list=input_list,
387 output_list=output_list,
388 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000389 ):
390 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100391
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000392 attr = None
393 if op["op"] == Op.NEGATE:
394 attr = ts.TosaSerializerAttribute()
395 attr.NegateAttribute(qinfo[0], qinfo[1])
396
397 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000398
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000399 compliance = self.tensorComplianceMetaData(
400 op, a.dtype, args_dict, result_tensor, error_name
401 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000402 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700403
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000404 def build_binary_broadcast(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100405 self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000406 ):
407 assert len(inputs) == 2
408 a, b = inputs
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100409 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100410
411 # Invalidate Input/Output list for error if checks.
412 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000413 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100414 pCount, cCount = op["operands"]
415 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000416 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100417 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000418 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100419
Les Bell729b0352021-11-24 10:28:21 +0000420 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100421 self.ser,
422 validator_fcns,
423 error_name,
424 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000425 input1=a,
426 input2=b,
427 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000428 output_dtype=result_tensor.dtype,
429 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100430 input_list=input_list,
431 output_list=output_list,
432 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000433 ):
434 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100435
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000437
Jeremy Johnson9a758382023-11-07 16:27:35 +0000438 compliance = self.tensorComplianceMetaData(
439 op, a.dtype, args_dict, result_tensor, error_name
440 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000441
442 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700443
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000444 def build_arithmetic_right_shift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100445 self,
446 rng,
447 op,
448 inputs,
449 args_dict,
450 validator_fcns=None,
451 error_name=None,
452 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000453 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000454 assert len(inputs) == 2
455 a, b = inputs
456 round = args_dict["round"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100457 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100458
459 # Invalidate Input/Output list for error if checks.
460 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000461 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100462 pCount, cCount = op["operands"]
463 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000464 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100465 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000466 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100467
Les Bell729b0352021-11-24 10:28:21 +0000468 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100469 self.ser,
470 validator_fcns,
471 error_name,
472 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000473 input1=a,
474 input2=b,
475 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000476 output_dtype=result_tensor.dtype,
477 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478 input_list=input_list,
479 output_list=output_list,
480 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000481 ):
482 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800483
484 attr = ts.TosaSerializerAttribute()
485 attr.ArithmeticRightShiftAttribute(round)
486
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000487 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000488
489 compliance = self.tensorComplianceMetaData(
490 op, a.dtype, args_dict, result_tensor, error_name
491 )
492
493 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800494
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100495 def build_mul(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100496 self,
497 rng,
498 op,
499 inputs,
500 args_dict,
501 validator_fcns=None,
502 error_name=None,
503 qinfo=None,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100504 ):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000505 # Note that mul is binary operator but it has a shift value tensor
506 assert len(inputs) == 3
507 a, b, s = inputs
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100508
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100509 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700510
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100511 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100512 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100513 result_tensor.setDtype(DType.INT32)
514
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100515 if error_name == ErrorIf.WrongOutputType:
516 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100517 outputDType = rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100518 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100519
520 # Invalidate Input/Output list for error if checks.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000521 input_list = [a.name, b.name, s.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100522 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100523 pCount, cCount = op["operands"]
524 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000525 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100526 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000527 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100528
Les Bell729b0352021-11-24 10:28:21 +0000529 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100530 self.ser,
531 validator_fcns,
532 error_name,
533 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000534 input1=a,
535 input2=b,
536 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100537 output_dtype=result_tensor.dtype,
538 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100539 input_list=input_list,
540 output_list=output_list,
541 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000542 ):
543 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700544
Jeremy Johnson0a042992024-02-28 13:20:05 +0000545 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100546
547 compliance = self.tensorComplianceMetaData(
548 op, a.dtype, args_dict, result_tensor, error_name
549 )
550
551 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700552
Jeremy Johnson587cc842024-02-08 11:45:44 +0000553 def build_table(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100554 self,
555 rng,
556 op,
557 inputs,
558 args_dict,
559 validator_fcns=None,
560 error_name=None,
561 qinfo=None,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000562 ):
563 assert len(inputs) == 1
564 a = inputs[0]
565 table = args_dict["table"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100566 result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700567
Kevin Chengfe392ce2021-10-18 21:51:55 +0000568 attr = ts.TosaSerializerAttribute()
569 attr.TableAttribute(table)
570
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100571 # Invalidate Input/Output list for error if checks.
572 input_list = [a.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000573 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100574 pCount, cCount = op["operands"]
575 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000576 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100577 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000578 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100579
Les Bell729b0352021-11-24 10:28:21 +0000580 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100581 self.ser,
582 validator_fcns,
583 error_name,
584 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000585 input_shape=a.shape,
586 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000587 output_dtype=result_tensor.dtype,
588 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100589 input_list=input_list,
590 output_list=output_list,
591 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000592 ):
593 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100594
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000595 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700596
Jeremy Johnson587cc842024-02-08 11:45:44 +0000597 compliance = self.tensorComplianceMetaData(
598 op, a.dtype, args_dict, result_tensor, error_name
599 )
600
601 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700602
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000603 def build_select(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100604 self,
605 rng,
606 op,
607 inputs,
608 args_dict,
609 validator_fcns=None,
610 error_name=None,
611 qinfo=None,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000612 ):
613 assert len(inputs) == 3
614 cond, a, b = inputs
615
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100616 result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100617
618 # Invalidate Input/Output list for error if checks.
619 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000620 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100621 pCount, cCount = op["operands"]
622 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000623 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100624 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000625 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100626
Les Bell729b0352021-11-24 10:28:21 +0000627 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100628 self.ser,
629 validator_fcns,
630 error_name,
631 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000632 input1=cond,
633 input2=a,
634 input3=b,
635 input_shape=a.shape,
636 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000637 output_dtype=result_tensor.dtype,
638 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100639 input_list=input_list,
640 output_list=output_list,
641 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000642 ):
643 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100644
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000645 self.ser.addOperator(
646 op["op"],
647 input_list,
648 output_list,
649 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000650 compliance = self.tensorComplianceMetaData(
651 op, a.dtype, args_dict, result_tensor, error_name
652 )
653
654 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700655
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a two-input comparison operator.

    The result tensor is produced by ``OutputShaper.binaryComparisonOp``.
    For ERROR_IF tests (``error_name`` set), the operand/result name lists
    are passed through ``TosaErrorIfArgGen.eiInvalidateInputOutputList``
    before validation. ``qinfo`` is accepted for interface uniformity but
    is not used.

    Returns:
        ``None`` if ``TosaErrorValidator.evValidateErrorIfs`` rejects the
        test; otherwise a ``TosaTestGen.BuildInfo`` holding the result
        tensor and its compliance metadata.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(
        self.ser, rng, lhs, rhs, error_name
    )

    # Total operand count is the sum of the two counts in op["operands"].
    first_count, second_count = op["operands"]
    operand_total = first_count + second_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700708
def build_argmax(
    self, rng, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Serialize an argmax operator over ``args_dict["axis"]``.

    The result tensor comes from ``OutputShaper.argmaxOp``; the axis is
    attached to the operator as an ``AxisAttribute``. ``qinfo`` is not used.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, rng, src, axis, error_name)

    in_names = [src.name]
    out_names = [out_tensor.name]
    # Total operand count is the sum of the two counts in op["operands"].
    first_count, second_count = op["operands"]
    operand_total = first_count + second_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700752
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator.

    ``args_dict`` supplies ``kernel``, ``stride`` and ``pad``, plus an
    optional ``acc_type``. When ``acc_type`` is absent (max_pool has no
    accumulator type), ``DType.UNKNOWN`` is recorded. For the
    ``WrongInputType`` ERROR_IF case with a non-INT8/UINT8 input, ``qinfo``
    is regenerated from the input and result dtypes. If ``qinfo`` is still
    ``None`` after validation, zero points ``0, 0`` are serialized.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # WrongInputType with a non-8-bit input: rebuild the zero points from
    # the actual input dtype and result dtype.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, out_tensor.dtype),
        ]

    in_names = [ifm.name]
    out_names = [out_tensor.name]
    # Total operand count is the sum of the two counts in op["operands"].
    first_count, second_count = op["operands"]
    operand_total = first_count + second_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100830
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D convolution operator.

    ``inputs`` is (input tensor, weights, bias). ``args_dict`` supplies
    ``acc_type``, ``stride``, ``pad`` (4 entries) and ``dilation``. For the
    ``WrongInputType`` ERROR_IF case with a non-INT8/UINT8 input, ``qinfo``
    is rebuilt before validation. The ``local_bound`` attribute is always
    serialized as ``False``.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    out_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # WrongInputType with a non-8-bit input: rebuild qinfo from the input
    # dtype and the result dtype.
    # NOTE(review): qinfo[1] is later passed to ConvAttribute, presumably as
    # the weight zero point, yet it is derived from the result dtype here
    # (weights.dtype may be intended) - confirm.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, out_tensor.dtype),
        ]

    in_names = [ifm.name, weights.name, bias.name]
    out_names = [out_tensor.name]
    operand_total = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_total,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=out_tensor.shape,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    # TODO: local_bound is not exercised yet; always serialize it as False.
    local_bound = False

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700918
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 3D convolution operator.

    ``inputs`` is (input tensor, weights, bias). ``args_dict`` supplies
    ``acc_type``, ``stride``, ``pad`` (6 entries) and ``dilation``. For the
    ``WrongInputType`` ERROR_IF case with a non-INT8/UINT8 input, ``qinfo``
    is rebuilt before validation. The ``local_bound`` attribute is always
    serialized as ``False``.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    out_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # WrongInputType with a non-8-bit input: rebuild qinfo from the input
    # dtype and the result dtype.
    # NOTE(review): qinfo[1] is later passed to ConvAttribute, presumably as
    # the weight zero point, yet it is derived from the result dtype here
    # (weights.dtype may be intended) - confirm.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, out_tensor.dtype),
        ]

    in_names = [ifm.name, weights.name, bias.name]
    out_names = [out_tensor.name]
    operand_total = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_total,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=out_tensor.shape,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    # TODO: local_bound is not exercised yet; always serialize it as False.
    local_bound = False

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001006
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D transpose convolution operator.

    ``inputs`` is (input tensor, weights, bias). ``args_dict`` supplies
    ``acc_type``, ``stride``, ``pad`` (4 entries, used as the output
    padding) and ``out_shape``. For the ``WrongInputType`` ERROR_IF case
    with a non-INT8/UINT8 input, ``qinfo`` is rebuilt before validation.
    The ``local_bound`` attribute is always serialized as ``False``.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    out_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # WrongInputType with a non-8-bit input: rebuild qinfo from the input
    # dtype and the result dtype.
    # NOTE(review): qinfo[1] is later passed to TransposeConvAttribute,
    # presumably as the weight zero point, yet it is derived from the result
    # dtype here (weights.dtype may be intended) - confirm.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, out_tensor.dtype),
        ]

    in_names = [ifm.name, weights.name, bias.name]
    out_names = [out_tensor.name]
    operand_total = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_total,
        output_list=out_names,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=out_tensor.shape,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    # TODO: local_bound is not exercised yet; always serialize it as False.
    local_bound = False

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, tconv_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001085
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D depthwise convolution operator.

    ``inputs`` is (input tensor, weights, bias). ``args_dict`` supplies
    ``acc_type``, ``stride``, ``pad`` and ``dilation``. For the
    ``WrongInputType`` ERROR_IF case with a non-INT8/UINT8 input, ``qinfo``
    is rebuilt before validation. The ``local_bound`` attribute is always
    serialized as ``False``.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    out_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weights,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # WrongInputType with a non-8-bit input: rebuild qinfo from the input
    # dtype and the result dtype.
    # NOTE(review): qinfo[1] is later passed to ConvAttribute, presumably as
    # the weight zero point, yet it is derived from the result dtype here
    # (weights.dtype may be intended) - confirm.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, out_tensor.dtype),
        ]

    in_names = [ifm.name, weights.name, bias.name]
    out_names = [out_tensor.name]
    operand_total = sum(op["operands"])
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        input_list=in_names,
        num_operands=operand_total,
        output_list=out_names,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weights.shape,
        output_shape=out_tensor.shape,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    # TODO: local_bound is not exercised yet; always serialize it as False.
    local_bound = False

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )
    self.ser.addOperator(op["op"], in_names, out_names, conv_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001172
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a fully-connected operator.

    ``inputs`` is (input tensor, weights, bias) and ``args_dict["acc_type"]``
    gives the accumulator dtype. The two ``qinfo`` entries are written into
    the ``FullyConnectedAttribute``.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 3
    ifm, weights, bias = inputs
    accum_dtype = args_dict["acc_type"]

    out_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weights, accum_dtype, error_name
    )

    in_names = [ifm.name, weights.name, bias.name]
    out_names = [out_tensor.name]
    # Total operand count is the sum of the two counts in op["operands"].
    first_count, second_count = op["operands"]
    operand_total = first_count + second_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weights.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    fc_attr = ts.TosaSerializerAttribute()
    fc_attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, fc_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001229
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a matrix-multiply operator from two input tensors.

    ``args_dict["acc_type"]`` gives the accumulator dtype. The two ``qinfo``
    entries are written into the ``MatMulAttribute``.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    accum_dtype = args_dict["acc_type"]
    out_tensor = OutputShaper.matmulOp(
        self.ser, rng, lhs, rhs, accum_dtype, error_name
    )

    in_names = [lhs.name, rhs.name]
    out_names = [out_tensor.name]
    # Total operand count is the sum of the two counts in op["operands"].
    first_count, second_count = op["operands"]
    operand_total = first_count + second_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        input2_shape=rhs.shape,
        input2_dtype=rhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
        accum_dtype=accum_dtype,
    )
    if not passed:
        return None

    matmul_attr = ts.TosaSerializerAttribute()
    matmul_attr.MatMulAttribute(qinfo[0], qinfo[1])
    self.ser.addOperator(op["op"], in_names, out_names, matmul_attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001286
def build_reduce(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Serialize a reduction operator over ``args_dict["axis"]``.

    The axis is attached as an ``AxisAttribute``. For a valid (non-ERROR_IF)
    ``Op.REDUCE_PRODUCT`` test, ``args_dict["n"]`` is set in place to the
    size of the reduced axis, because the compliance metadata needs the
    number of products. ``qinfo`` is not used.

    Returns:
        ``None`` if ERROR_IF validation fails; otherwise a
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance data.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, rng, src, axis, error_name)

    in_names = [src.name]
    out_names = [out_tensor.name]
    # Total operand count is the sum of the two counts in op["operands"].
    first_count, second_count = op["operands"]
    operand_total = first_count + second_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Compliance checking needs the number of multiplied elements.
        args_dict["n"] = src.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001335
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CLAMP operator with randomly drawn bounds.

    Two random values of the input dtype are drawn from ``rng``. Normally the
    smaller becomes the minimum and the larger the maximum. For the
    ``ErrorIf.MaxSmallerMin`` test, the pair is redrawn until the two values
    differ, then the roles are swapped so that max < min.

    The bounds are serialized as 4 little-endian bytes: float32 for
    BF16/FP16/FP32 inputs, int32 for INT8/INT16 inputs, and zero (int32) for
    any other dtype.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    dtype = input_tensor.dtype

    output_tensor = OutputShaper.unaryOp(self.ser, rng, input_tensor, error_name)

    def draw_pair():
        # Two independent random values of the input dtype.
        return [rng.randNumberDType(dtype), rng.randNumberDType(dtype)]

    bounds = draw_pair()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal bounds could never give max < min, so redraw until they differ.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        # Deliberately inverted bounds to provoke the error.
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [input_tensor.name], [output_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    if dtype in (DType.BF16, DType.FP16, DType.FP32):
        if dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        min_bytes = struct.pack("<f", min_val)
        max_bytes = struct.pack("<f", max_val)
    elif dtype in (DType.INT8, DType.INT16):
        min_bytes = struct.pack("<i", min_val)
        max_bytes = struct.pack("<i", max_val)
    else:
        # Unsupported input types still need well-formed bounds, so that
        # serialization does not fail internally.
        min_bytes = max_bytes = struct.pack("<i", 0)

    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(self.ser.builder, min_bytes, max_bytes)
    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001415
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a single-input activation operator that carries no attribute.

    The output tensor is produced by ``OutputShaper.unaryOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    output_tensor = OutputShaper.unaryOp(self.ser, rng, input_tensor, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [input_tensor.name], [output_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        output_tensor,
        self.tensorComplianceMetaData(
            op, input_tensor.dtype, args_dict, output_tensor, error_name
        ),
    )
Won Jeon78155c62023-06-10 00:20:04 +00001463
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONCAT or CONCAT_SHAPE operator over all ``inputs``.

    CONCAT_SHAPE always joins along axis 0 and is serialized without an
    attribute. CONCAT takes its axis from ``args_dict["axis"]`` and
    serializes it as an AxisAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata based on
        the first input's dtype), or None when ``evValidateErrorIfs`` rejects
        the test.
    """
    is_shape_concat = op["op"] == Op.CONCAT_SHAPE
    axis = 0 if is_shape_concat else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    output_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [tensor.name for tensor in inputs],
        [output_tensor.name],
    )

    first_input = inputs[0]
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first_input.shape,
        output_shape=output_tensor.shape,
        input_dtype=first_input.dtype,
        output_dtype=output_tensor.dtype,
        inputs=inputs,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    if op["op"] == Op.CONCAT:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    else:
        # Only the shape variant is expected to reach this branch.
        assert is_shape_concat
        concat_attr = None
    self.ser.addOperator(op["op"], input_list, output_list, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, first_input.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001529
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator.

    ``inputs`` holds the tensor to pad and the padding tensor. ``args_dict``
    supplies ``"pad"`` (used for output shaping and validation) and the two
    candidate pad constants ``"pad_const_int"`` and ``"pad_const_fp"``. The
    constant is chosen by whether the input dtype is floating point.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 2
    input_tensor, pad_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    output_tensor = OutputShaper.padOp(
        self.ser, rng, input_tensor, padding, error_name
    )

    # The pad constant is serialized as 4 little-endian bytes: float32 for
    # floating-point inputs, int32 otherwise.
    pack_format, pad_const = (
        ("<f", pad_const_float)
        if gtu.dtypeIsFloat(input_tensor.dtype)
        else ("<i", pad_const_int)
    )
    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, struct.pack(pack_format, pad_const))

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [input_tensor.name, pad_tensor.name], [output_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=input_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001593
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator that queries the size of one axis of its input.

    The axis comes from ``args_dict["axis"]`` and is serialized as an
    AxisAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    dim_axis = args_dict["axis"]
    output_tensor = OutputShaper.dimOp(self.ser, rng, input_tensor, dim_axis, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [input_tensor.name], [output_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=input_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_shape=output_tensor.shape,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    # No compliance metadata is generated for this operator here.
    return TosaTestGen.BuildInfo(output_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001640
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESHAPE operator.

    ``inputs`` holds the data tensor and a shape tensor. The shape tensor is
    passed to the operator as a second input, while ``args_dict["new_shape"]``
    drives output-shape computation. No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 2
    input_tensor, shape_tensor = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    output_tensor = OutputShaper.reshapeOp(
        self.ser, rng, input_tensor, new_shape, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [input_tensor.name, shape_tensor.name],
        [output_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001689
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a REVERSE operator along ``args_dict["axis"]``.

    The output shape comes from ``OutputShaper.unaryOp``. The axis is
    serialized as an AxisAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and no compliance
        metadata (None), or None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    reverse_axis = args_dict["axis"]
    output_tensor = OutputShaper.unaryOp(self.ser, rng, input_tensor, error_name)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [input_tensor.name], [output_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=reverse_axis,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reverse_axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    # NOTE(review): unlike TRANSPOSE/SLICE/TILE, no compliance metadata is
    # produced here; confirm whether REVERSE is meant to gain it.
    return TosaTestGen.BuildInfo(output_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001736
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    The permutation is serialized as a TransposeAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 1
    input_tensor = inputs[0]
    permutation = args_dict["perms"]

    output_tensor = OutputShaper.transposeOp(
        self.ser, rng, input_tensor, permutation, error_name
    )

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(permutation)

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [input_tensor.name], [output_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        perms=permutation,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=input_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001790
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SLICE operator.

    ``inputs`` holds the data tensor plus the start and size tensors, all
    passed to the operator as inputs. The constant ``args_dict["start"]`` and
    ``args_dict["size"]`` values drive output shaping and validation. No
    attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 3
    input_tensor, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    output_tensor = OutputShaper.sliceOp(
        self.ser, rng, input_tensor, start_values, size_values, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [input_tensor.name, start_tensor.name, size_tensor.name],
        [output_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=input_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001845
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TILE operator.

    ``inputs`` holds the data tensor and a multiples tensor. The multiples
    tensor is passed to the operator as a second input, while
    ``args_dict["multiples"]`` drives output-shape computation. No attribute
    is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 2
    input_tensor, multiples_tensor = inputs[0], inputs[1]
    multiples_values = args_dict["multiples"]
    output_tensor = OutputShaper.tileOp(
        self.ser, rng, input_tensor, multiples_values, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [input_tensor.name, multiples_tensor.name],
        [output_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=input_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=input_tensor.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
        input1=input_tensor,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001897
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    No attribute is serialized. Validation and compliance use the values
    tensor's shape and dtype.

    Returns:
        ``TosaTestGen.BuildInfo`` (result tensor, compliance metadata), or
        None when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    output_tensor = OutputShaper.gatherOp(
        self.ser, rng, values_tensor, indices_tensor, error_name
    )

    # For ERROR_IF tests the input/output name lists may be invalidated.
    placeholder_count, const_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_tensor.name, indices_tensor.name],
        [output_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=output_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=output_tensor.dtype,
        result_tensors=[output_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, output_tensor, error_name
    )
    return TosaTestGen.BuildInfo(output_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001947
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test.

    ``inputs`` holds the (values_in, indices, input) tensors. The result
    tensor comes from OutputShaper.scatterOp; for ERROR_IF tests the operand
    name lists may be invalidated before validation.

    Returns a TosaTestGen.BuildInfo carrying compliance metadata, or None
    when TosaErrorValidator.evValidateErrorIfs returns a falsy result.
    """
    assert len(inputs) == 3
    values_in, indices, input_tensor = inputs
    result_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, input_tensor, error_name
    )

    # Operand counts from the op definition (placeholders + constants)
    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [result_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_in.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001996
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` holds the (input, scale, offset, border) tensors. The scale,
    offset and border are supplied as operand tensors, so the serialized
    ResizeAttribute is written with empty scale/offset/border lists and
    only carries the mode. ``args_dict`` provides mode, scale, offset,
    border and output_dtype, which drive shape inference and validation.

    Returns a TosaTestGen.BuildInfo, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy result.
    """
    assert len(inputs) == 4
    input_tensor, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    operand_names = [
        tensor.name
        for tensor in (input_tensor, scale_input, offset_input, border_input)
    ]
    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, operand_names, [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # scale/offset/border travel as input tensors, so the attribute gets
    # empty lists for them
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002076
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST operator test.

    The single supplied tensor is registered directly as the graph output;
    no operator is added. Returns a TosaTestGen.BuildInfo for that tensor.
    """
    assert len(inputs) == 1
    (const_tensor,) = inputs
    self.ser.addOutputTensor(const_tensor)

    return TosaTestGen.BuildInfo(
        const_tensor,
        self.tensorComplianceMetaData(
            op, const_tensor.dtype, args_dict, const_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002096
2097 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST operator test.

    Converts the single input tensor to ``args_dict["out_type"]``; the result
    tensor is produced by OutputShaper.typeConversionOp.

    Returns a TosaTestGen.BuildInfo, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy result.
    """
    assert len(inputs) == 1
    (source,) = inputs
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, source, target_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [source.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        output_shape=result_tensor.shape,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002148
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``inputs`` holds the (value, multiplier, shift) tensors. ``args_dict``
    supplies output_dtype, scale (scale32 flag), double_round, per_channel,
    shift and multiplier.

    Input/output zero points are drawn from ``rng`` based on the dtypes.
    Zero-point ERROR_IF tests force a non-zero value. For valid scale32 tests
    the placeholder data file is clamped in place, so that
    ``value - input_zp`` lies within the apply_scale_32 shift bounds.

    Note: the ``qinfo`` argument is ignored and replaced by
    (input_zp, output_zp) for validation.

    Returns a TosaTestGen.BuildInfo, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy result.
    """
    assert len(inputs) == 3
    val, multiplier_val, shift_val = inputs
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # One (multiplier, shift) pair per channel of the last axis, or a single
    # pair for the whole tensor
    nc = val.shape[-1] if per_channel else 1

    input_unsigned = False
    output_unsigned = False

    # NOTE: rng draw order (input zp, then output zp) is kept fixed so that
    # generated tests stay reproducible for a given seed.
    if val.dtype == DType.INT8:
        input_zp = rng.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = rng.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = rng.randInt(-128, 128)
        if input_zp == 0:
            # ERROR_IF test requires a non-zero zero point
            input_zp = input_zp + rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = rng.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = rng.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = rng.randInt(-128, 128)
        if output_zp == 0:
            # ERROR_IF test requires a non-zero zero point
            output_zp = output_zp + rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    # Per-channel bounds [-(1 << (shift-1)), (1 << (shift-1)) - 1]
    min_shift_value_arr = np.zeros(nc, dtype=np.int64)
    max_shift_value_arr = np.zeros(nc, dtype=np.int64)
    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        "build_rescale: multiplier=%s shift=%s inzp=%s outzp=%s",
        multiplier_arr,
        shift_arr,
        input_zp,
        output_zp,
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert to the expected dtype. Compare the
        # element values themselves: ``val_adj.all()`` would collapse the
        # array to a single bool and make this check meaningless.
        dtype_info = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_info.min) and np.all(
            val_adj <= dtype_info.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002315
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Normally this is a rank-0 BOOL constant holding ``cond``. The
    CondIfCondNotMatchingBool ERROR_IF test swaps in a wrong dtype, and the
    CondIfCondShapeNotSizeOne test uses a shape of [2] or [1, 2].
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2332
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each output a constant.

    Of the supplied (then, else) tensors only the shape of the first is
    used. Each block is filled with an INT32 const node of that shape. For
    the CondIfOutputList*GraphMismatch ERROR_IF tests, the selected block
    instead gets a const with a perturbed shape.

    Returns a TosaTestGen.BuildInfo, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy result.
    """
    assert len(inputs) == 2
    then_tens, _ = inputs
    cond = args_dict["condition"]

    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    # Create an incorrect output shape for error_if tests
    if error_name in [
        ErrorIf.CondIfOutputListThenGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
    ]:
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            # Small dims are only grown so the shape stays positive
            if dim > 3:
                delta = rng.choice([-3, -2, 2, 3])
            else:
                delta = rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    # Result tensor of the COND_IF itself
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Each block holds a single const node, which is the block's output
    block_specs = (
        (then_block, then_arr, ErrorIf.CondIfOutputListThenGraphMismatch),
        (else_block, else_arr, ErrorIf.CondIfOutputListElseGraphMismatch),
    )
    for block_name, block_arr, mismatch_error in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_tens = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    compliance = self.tensorComplianceMetaData(
        op, dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002418
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to a, b.

    For FP32/BF16/FP16/INT32 inputs, THEN uses add and ELSE uses sub. For
    INT8/INT16, THEN uses logical_right_shift and ELSE uses
    logical_left_shift. Compliance metadata is computed for the op of the
    branch selected by ``args_dict["condition"]``.

    For the CondIf{Input,Output}List*GraphMismatch ERROR_IF tests, the
    chosen block gets an input or output with a perturbed shape.

    Returns a TosaTestGen.BuildInfo, or None when
    TosaErrorValidator.evValidateErrorIfs returns a falsy result.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Only grow small dimensions so the perturbed shape stays
            # positive (same scheme as build_cond_if_const); subtracting
            # 2 or 3 from a dim <= 3 would give an invalid zero/negative dim.
            incorrect_shape[i] += (
                rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2523
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test that accumulates the input into a zeroed tensor.

    The loop carries three tensors: an INT32 scalar iteration counter, the
    input ``a`` (passed through unchanged) and an accumulator initialised to
    zeros. The COND block evaluates ``iter > 0``; the BODY block computes
    ``acc = a + acc`` and ``iter = iter - 1``.

    For ERROR_IF tests (``error_name`` set) the relevant tensor lists or the
    condition output are deliberately corrupted (perturbed shapes, a non-BOOL
    or non-scalar condition) so that the selected error is exercised.

    Args:
        rng: random generator used to choose the corruptions for ERROR_IF tests.
        op: operator entry from TOSA_OP_LIST.
        inputs: single-element list holding the input tensor ``a``.
        args_dict: test arguments; ``"iterations"`` sets the initial counter.
        validator_fcns: ERROR_IF validator functions to check against.
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this operator.

    Returns:
        TosaTestGen.BuildInfo for the accumulator output, or None when the
        ERROR_IF validation rejects the generated test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    # Loop counter; named iter_tens so it does not shadow the builtin iter()
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Perturb every dimension so the output list no longer matches the inputs
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        # Shape-corrupted copies of the loop-carried tensors, used by the
        # graph input/output list mismatch errors below
        incorrect_iter = deepcopy(iter_tens)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            # A scalar has no dimension to perturb, so add one of random size
            incorrect_iter.shape.append(rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3])

    # COND block (input: iter, a, acc; output: cond_tens)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        # Condition output must be a single element; make it larger
        choice = rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: iter, a, acc; output: iter, a, acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter_tens)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )

    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002661
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D test from its two input tensors.

    The output tensors come from OutputShaper.fft2dOp. For ERROR_IF tests the
    operator's input/output name lists may be corrupted before validation.

    Args:
        rng: random generator for output shaping and ERROR_IF corruption.
        op: operator entry from TOSA_OP_LIST.
        inputs: the two input tensors.
        args_dict: test arguments; ``"inverse"`` selects the inverse transform.
        validator_fcns: ERROR_IF validator functions to check against.
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this operator.

    Returns:
        TosaTestGen.BuildInfo holding every result tensor and one compliance
        entry per result, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    tens1, tens2 = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, tens1, tens2, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    # ERROR_IF tests may corrupt the operator's input/output name lists
    in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [tens1.name, tens2.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=tens1,
        input2=tens2,
        input_shape=tens1.shape,
        input_dtype=tens1.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=in_list,
        output_list=out_list,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    fft_attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now the local bound attribute is False
    fft_attr.FFTAttribute(inverse, False)
    self.ser.addOperator(op["op"], in_list, out_list, fft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, tens1.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002726
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D test from a single input tensor.

    The output tensors come from OutputShaper.rfft2dOp. For ERROR_IF tests the
    operator's input/output name lists may be corrupted before validation.

    Args:
        rng: random generator for output shaping and ERROR_IF corruption.
        op: operator entry from TOSA_OP_LIST.
        inputs: single-element list holding the input tensor.
        args_dict: test arguments (used for compliance meta data).
        validator_fcns: ERROR_IF validator functions to check against.
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this operator.

    Returns:
        TosaTestGen.BuildInfo holding every result tensor and one compliance
        entry per result, or None if the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, rng, src, error_name)

    placeholders, consts = op["operands"]
    operand_total = placeholders + consts

    # ERROR_IF tests may corrupt the operator's input/output name lists
    in_list, out_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [src.name],
        [res.name for res in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=[res.shape for res in results],
        output_dtype=[res.dtype for res in results],
        result_tensors=results,
        input_list=in_list,
        output_list=out_list,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    rfft_attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now the local bound attribute is False
    rfft_attr.RFFTAttribute(False)
    self.ser.addOperator(op["op"], in_list, out_list, rfft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, src.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002784
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a test for a two-input shape operator.

    The result tensor comes from OutputShaper.addShapeOp. For ERROR_IF tests
    the operator's input/output name lists may be corrupted before validation.

    Args:
        rng: random generator for output shaping and ERROR_IF corruption.
        op: operator entry from TOSA_OP_LIST.
        inputs: the two input tensors.
        args_dict: test arguments (used for compliance meta data).
        validator_fcns: ERROR_IF validator functions to check against.
        error_name: the ERROR_IF being tested, or None for a positive test.
        qinfo: not used by this operator.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when the
        ERROR_IF validation rejects the generated test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # The first argument is the random generator (as in the other build
    # functions); previously `self` was passed here by mistake
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2837
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Compute the shape, rank and dtype filters used to generate an op's tests.

    Requested filters are intersected with what the operator allows. With no
    rank or shape request, ranks default to the smaller of the operator's
    range and 1-4 inclusive. For negative tests the ERROR_IF validator's
    ``param_reqs`` override the filters, and the shape list is trimmed to two.

    Returns:
        dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or an unrecognised testType.
    """
    # Default ranks 1-4 inclusive keep test sizes reasonably small
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Keep only the requested ranks the operator supports
        ranks = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapes[0] is None:
        # Bound by whichever of the operator / default rank ranges is smaller
        op_ranks = range(rmin, rmax + 1)
        ranks = default_ranks if len(op_ranks) > len(default_ranks) else op_ranks
    else:
        ranks = range(rmin, rmax + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # A list entry is a dtype combination, matched on its first dtype
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }

    if testType != "negative" or validator is None:
        return None

    reqs = validator(check=False, op=op)["param_reqs"]
    rank_req = reqs["rank"]
    dtype_req = reqs["dtype"]
    shape_req = reqs["shape"]
    return {
        # Fall back to at most two shapes to keep test numbers small
        "shapeFilter": shapes[:2] if shape_req is None else shape_req,
        "rankFilter": ranks if rank_req is None else rank_req,
        "dtypeFilter": dtypes if dtype_req is None else dtype_req,
    }
2918
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for one operator.

    Args:
        opName: key of the operator in TOSA_OP_LIST.
        shapeFilter: list of shapes to restrict tests to; None (the default)
            means no shape restriction (equivalent to the old ``[None]``).
        rankFilter: optional list of ranks to restrict tests to.
        dtypeFilter: optional list of dtypes to restrict tests to.
        testType: "positive" for valid tests, "negative" for ERROR_IF tests.

    Returns:
        list of tuples ``(opName, testStr, dtype, error_name, shapeList, args)``.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    # Avoid a shared mutable default; [None] means "no shape restriction"
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName)) from None

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            "genOpTestList: Error=%s, Filters S=%s, R=%s, T=%s",
            error_name,
            cleanShapeFilter,
            cleanRankFilter,
            cleanDtypeFilter,
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument lists consist of tuples of the (str, args)
                    # string representation and the build function arguments
                    if agen_fcn:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng

                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive" and "invalid_test_validators" in op:
        # Remove tests which are expected to fail but don't correlate to an
        # ERROR_IF statement
        invalid_test_validators = op["invalid_test_validators"]
        testList = [
            test
            for test in testList
            if not any(
                validator_fcn(
                    opName=test[0],
                    input_dtype=test[2],
                    shapeList=test[4],
                    args=test[5],
                )
                for validator_fcn in invalid_test_validators
            )
        ]

    return testList
3044
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Create, build and serialize a single test.

    Looks up the operator, creates a serializer, and generates input tensors
    with the op's tensor-values generator. It then calls the op's build
    function. If the build returns a result, the test is serialized with its
    meta data ("data_gen" and/or "compliance" where available). Otherwise an
    error is logged.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # Every test gets its own serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts
    # CONCAT-style ops take a variable number of inputs (one per shape)
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        copies = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Choose the random number generator for tensors and the build
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    qinfo = (
        None
        if qgen is None
        else qgen(build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name)
    )

    # Only the dictionary-based tvg/build_fcn interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )

    # Extra meta data for the desc.json
    tensMeta = {}
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The build rejected the test, so there is nothing to serialize
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003147
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST per kernel size.

    For each kernel, a shallow copy of the matching template is added under a
    name such as ``conv2d_3x3`` with "filter", "template" (False) and
    "real_name" set. Every entry still flagged as a template is then removed.
    Does nothing if "conv2d_TEMPLATE" is absent (lists already created).
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to generate; level 8K testing uses the maximum kernel size
    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    # Insertion order: each 2D kernel for all three 2D ops, then the 3D ops
    expansions = [
        (kernel, base)
        for kernel in kernels_2d
        for base in ("conv2d", "depthwise_conv2d", "transpose_conv2d")
    ] + [(kernel, "conv3d") for kernel in kernels_3d]

    for kernel, base in expansions:
        name = "{}_{}".format(base, "x".join("{}".format(dim) for dim in kernel))
        entry = op_list[base + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        entry["real_name"] = base
        op_list[name] = entry

    # Delete templates in a second pass - never delete keys while iterating
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
3207
def initOpListDefaults(self):
    """Validate each TOSA_OP_LIST entry and fill in optional defaults.

    Every op must provide a 2-element "operands", a 4-element "build_fcn",
    "types" and "op". A missing "rank" is set to DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op found with a missing/invalid field.
    """
    for op_name, op_entry in self.TOSA_OP_LIST.items():
        # Required: (placeholder count, const count)
        try:
            _placeholders, _consts = op_entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        # Required: (build, tensor gen, tensor values gen, argument gen)
        try:
            _build, _tgen, _tvgen, _agen = op_entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        # Required: the dtypes to test and the TOSA Op itself
        for field, what in (("types", "a valid type list"), ("op", "the Op field")):
            try:
                op_entry[field]
            except KeyError:
                raise Exception(
                    "Op {} is missing {} in TOSA_OP_LIST".format(op_name, what)
                )

        # Optional: rank range
        if "rank" not in op_entry:
            op_entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003249
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#     if not specified, defaults to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of the functions (build_operator(), TensorGen,
#     TensorValuesGen, ArgGen)
# 'types': array of datatypes to be tested
# Floating-point types
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (INT4 is deliberately excluded)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer then floating-point types (INT4 excluded)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

# Boolean only
TYPE_BOOL = [DType.BOOL]
# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer (no INT4) and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
# FP32 and INT16
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integers and floating-point types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07003294
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003295 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07003296 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003297 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003298 [DType.INT8, DType.INT8, DType.INT32],
3299 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01003300 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003301 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01003302 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003303 [DType.FP32, DType.FP32, DType.FP32],
Won Jeon2c34b462024-02-06 18:37:00 +00003304 [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
3305 [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
Kevin Cheng989cb052021-04-28 16:29:44 -07003306 ]
3307
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003308 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003309
3310 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003311 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003312 "argmax": {
3313 "op": Op.ARGMAX,
3314 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003315 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003316 "build_fcn": (
3317 build_argmax,
3318 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003319 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003320 TosaArgGen.agAxis,
3321 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003322 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003323 "error_if_validators": (
3324 TosaErrorValidator.evAxisSmallerZero,
3325 TosaErrorValidator.evAxisLargerRank,
3326 TosaErrorValidator.evArgmaxOutputRankMismatch,
3327 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3328 TosaErrorValidator.evWrongRank,
3329 TosaErrorValidator.evWrongInputType,
3330 TosaErrorValidator.evWrongOutputType,
3331 TosaErrorValidator.evWrongInputList,
3332 TosaErrorValidator.evWrongOutputList,
3333 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003334 "data_gen": {
3335 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3336 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003337 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "avg_pool2d": {
3339 "op": Op.AVG_POOL2D,
3340 "operands": (1, 0),
3341 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 "build_fcn": (
3343 build_pool2d,
3344 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003345 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003346 TosaArgGen.agPooling,
3347 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003348 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003349 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003350 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003351 "error_if_validators": (
3352 TosaErrorValidator.evKernelSmallerOne,
3353 TosaErrorValidator.evStrideSmallerOne,
3354 TosaErrorValidator.evPadSmallerZero,
3355 TosaErrorValidator.evWrongRank,
3356 TosaErrorValidator.evWrongInputType,
3357 TosaErrorValidator.evWrongOutputType,
3358 TosaErrorValidator.evWrongInputList,
3359 TosaErrorValidator.evWrongOutputList,
3360 TosaErrorValidator.evInputZeroPointNotZero,
3361 TosaErrorValidator.evOutputZeroPointNotZero,
3362 TosaErrorValidator.evPadLargerEqualKernel,
3363 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003364 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003365 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003366 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003367 "data_gen": {
3368 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003370 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003371 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003372 "conv2d_TEMPLATE": {
3373 "op": Op.CONV2D,
3374 "operands": (1, 2),
3375 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003376 "build_fcn": (
3377 build_conv2d,
3378 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003379 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380 TosaArgGen.agConv,
3381 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003382 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003383 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003384 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3385 "error_if_validators": (
3386 TosaErrorValidator.evWrongInputType,
3387 TosaErrorValidator.evWrongOutputType,
3388 TosaErrorValidator.evWrongInputList,
3389 TosaErrorValidator.evWrongOutputList,
3390 TosaErrorValidator.evInputZeroPointNotZero,
3391 TosaErrorValidator.evWeightZeroPointNotZero,
3392 TosaErrorValidator.evPadSmallerZero,
3393 TosaErrorValidator.evStrideSmallerOne,
3394 TosaErrorValidator.evDilationSmallerOne,
3395 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003396 TosaErrorValidator.evConvOutputShapeMismatch,
3397 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003398 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003399 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003400 "data_gen": {
3401 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3402 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003403 "template": True,
3404 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003405 # Templated operator. Filled in by createDynamicOpLists
3406 "conv3d_TEMPLATE": {
3407 "op": Op.CONV3D,
3408 "operands": (1, 2),
3409 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003410 "build_fcn": (
3411 build_conv3d,
3412 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003413 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003414 TosaArgGen.agConv,
3415 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003416 "qgen": TosaQuantGen.qgConv,
3417 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003418 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3419 "error_if_validators": (
3420 TosaErrorValidator.evWrongInputType,
3421 TosaErrorValidator.evWrongOutputType,
3422 TosaErrorValidator.evWrongInputList,
3423 TosaErrorValidator.evWrongOutputList,
3424 TosaErrorValidator.evInputZeroPointNotZero,
3425 TosaErrorValidator.evWeightZeroPointNotZero,
3426 TosaErrorValidator.evPadSmallerZero,
3427 TosaErrorValidator.evStrideSmallerOne,
3428 TosaErrorValidator.evDilationSmallerOne,
3429 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003430 TosaErrorValidator.evConvOutputShapeMismatch,
3431 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003432 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003433 ),
evacha0147ab1762024-01-29 13:23:23 +00003434 "data_gen": {
3435 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3436 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003437 "template": True,
3438 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003439 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003440 "depthwise_conv2d_TEMPLATE": {
3441 "op": Op.DEPTHWISE_CONV2D,
3442 "operands": (1, 2),
3443 "filter": [1, 1],
3444 "rank": (4, 4),
3445 "build_fcn": (
3446 build_depthwise_conv2d,
3447 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003448 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003449 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003450 ),
3451 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003452 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003453 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3454 "error_if_validators": (
3455 TosaErrorValidator.evWrongInputType,
3456 TosaErrorValidator.evWrongOutputType,
3457 TosaErrorValidator.evWrongInputList,
3458 TosaErrorValidator.evWrongOutputList,
3459 TosaErrorValidator.evInputZeroPointNotZero,
3460 TosaErrorValidator.evWeightZeroPointNotZero,
3461 TosaErrorValidator.evPadSmallerZero,
3462 TosaErrorValidator.evStrideSmallerOne,
3463 TosaErrorValidator.evDilationSmallerOne,
3464 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003465 TosaErrorValidator.evConvOutputShapeMismatch,
3466 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003467 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003468 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003469 "data_gen": {
3470 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3471 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003472 "template": True,
3473 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003474 "fully_connected": {
3475 "op": Op.FULLY_CONNECTED,
3476 "operands": (1, 2),
3477 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003478 "build_fcn": (
3479 build_fully_connected,
3480 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003481 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003482 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003483 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003484 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003485 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003486 "error_if_validators": (
3487 TosaErrorValidator.evInputZeroPointNotZero,
3488 TosaErrorValidator.evWeightZeroPointNotZero,
3489 TosaErrorValidator.evWrongRank,
3490 TosaErrorValidator.evWrongInputType,
3491 TosaErrorValidator.evWrongOutputType,
3492 TosaErrorValidator.evWrongInputList,
3493 TosaErrorValidator.evWrongOutputList,
3494 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003495 "data_gen": {
3496 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3497 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003498 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003499 "matmul": {
3500 "op": Op.MATMUL,
3501 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003502 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003503 "build_fcn": (
3504 build_matmul,
3505 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003506 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003507 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003508 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003509 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003510 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003511 "error_if_validators": (
3512 TosaErrorValidator.evInputZeroPointNotZero,
3513 TosaErrorValidator.evWrongRank,
3514 TosaErrorValidator.evWrongInputType,
3515 TosaErrorValidator.evWrongOutputType,
3516 TosaErrorValidator.evWrongInputList,
3517 TosaErrorValidator.evWrongOutputList,
3518 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003519 "data_gen": {
3520 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003521 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003522 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003523 "max_pool2d": {
3524 "op": Op.MAX_POOL2D,
3525 "operands": (1, 0),
3526 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003527 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003528 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003530 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003531 TosaArgGen.agPooling,
3532 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003533 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003534 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003535 "error_if_validators": (
3536 TosaErrorValidator.evKernelSmallerOne,
3537 TosaErrorValidator.evStrideSmallerOne,
3538 TosaErrorValidator.evPadSmallerZero,
3539 TosaErrorValidator.evWrongRank,
3540 TosaErrorValidator.evWrongInputType,
3541 TosaErrorValidator.evWrongOutputType,
3542 TosaErrorValidator.evWrongInputList,
3543 TosaErrorValidator.evWrongOutputList,
3544 TosaErrorValidator.evPadLargerEqualKernel,
3545 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003546 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003547 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003548 "data_gen": {
3549 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3550 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003551 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003552 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003553 "transpose_conv2d_TEMPLATE": {
3554 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003555 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003556 "rank": (4, 4),
3557 "build_fcn": (
3558 build_transpose_conv2d,
3559 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003560 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003561 TosaArgGen.agTransposeConv2D,
3562 ),
3563 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003564 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003565 "invalid_test_validators": (
3566 TosaInvalidValidator.ivHeightWidthInvalid,
3567 TosaInvalidValidator.ivNonPositiveOutputShape,
3568 ),
3569 "error_if_validators": (
3570 TosaErrorValidator.evWrongInputType,
3571 TosaErrorValidator.evWrongOutputType,
3572 TosaErrorValidator.evWrongInputList,
3573 TosaErrorValidator.evWrongOutputList,
3574 TosaErrorValidator.evInputZeroPointNotZero,
3575 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003576 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003577 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003578 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003579 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003580 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003581 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003582 "data_gen": {
3583 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3584 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003585 "template": True,
3586 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003587 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003588 "clamp": {
3589 "op": Op.CLAMP,
3590 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003591 "build_fcn": (
3592 build_clamp,
3593 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003594 TosaTensorValuesGen.tvgLazyGenDefault,
3595 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003596 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003597 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003598 "error_if_validators": (
3599 TosaErrorValidator.evMaxSmallerMin,
3600 TosaErrorValidator.evWrongInputType,
3601 TosaErrorValidator.evWrongOutputType,
3602 TosaErrorValidator.evWrongInputList,
3603 TosaErrorValidator.evWrongOutputList,
3604 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003605 "data_gen": {
3606 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3607 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003608 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003609 "sigmoid": {
3610 "op": Op.SIGMOID,
3611 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003612 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003613 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003615 TosaTensorValuesGen.tvgLazyGenDefault,
3616 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003617 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003618 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003619 "error_if_validators": (
3620 TosaErrorValidator.evWrongInputType,
3621 TosaErrorValidator.evWrongOutputType,
3622 TosaErrorValidator.evWrongInputList,
3623 TosaErrorValidator.evWrongOutputList,
3624 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003625 "data_gen": {
3626 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3627 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003628 },
3629 "tanh": {
3630 "op": Op.TANH,
3631 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003633 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003635 TosaTensorValuesGen.tvgLazyGenDefault,
3636 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evWrongInputType,
3641 TosaErrorValidator.evWrongOutputType,
3642 TosaErrorValidator.evWrongInputList,
3643 TosaErrorValidator.evWrongOutputList,
3644 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003645 "data_gen": {
3646 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3647 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003648 "compliance": {
3649 "abs_error_lower_bound": 0.5,
3650 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003651 },
Won Jeon78155c62023-06-10 00:20:04 +00003652 "erf": {
3653 "op": Op.ERF,
3654 "operands": (1, 0),
3655 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003656 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003657 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003658 TosaTensorValuesGen.tvgLazyGenDefault,
3659 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003660 ),
3661 "types": TYPE_FP,
3662 "error_if_validators": (
3663 TosaErrorValidator.evWrongInputType,
3664 TosaErrorValidator.evWrongOutputType,
3665 TosaErrorValidator.evWrongInputList,
3666 TosaErrorValidator.evWrongOutputList,
3667 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003668 "data_gen": {
3669 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3670 },
3671 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003672 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003673 # Elementwise Binary Operators
3674 "add": {
3675 "op": Op.ADD,
3676 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003677 "build_fcn": (
3678 build_binary_broadcast,
3679 TosaTensorGen.tgBroadcastFuzz,
3680 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003681 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003682 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003683 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003684 "error_if_validators": (
3685 TosaErrorValidator.evRankMismatch,
3686 TosaErrorValidator.evWrongInputType,
3687 TosaErrorValidator.evWrongOutputType,
3688 TosaErrorValidator.evWrongInputList,
3689 TosaErrorValidator.evWrongOutputList,
3690 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003691 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003692 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003693 "data_gen": {
3694 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3695 },
3696 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003697 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003698 "arithmetic_right_shift": {
3699 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3700 "operands": (2, 0),
3701 "build_fcn": (
3702 build_arithmetic_right_shift,
3703 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003704 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003705 TosaArgGen.agArithmeticRightShift,
3706 ),
3707 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003708 "error_if_validators": (
3709 TosaErrorValidator.evRankMismatch,
3710 TosaErrorValidator.evWrongInputType,
3711 TosaErrorValidator.evWrongOutputType,
3712 TosaErrorValidator.evWrongInputList,
3713 TosaErrorValidator.evWrongOutputList,
3714 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003715 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003716 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003717 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003718 "bitwise_and": {
3719 "op": Op.BITWISE_AND,
3720 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003721 "build_fcn": (
3722 build_binary_broadcast,
3723 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003724 TosaTensorValuesGen.tvgLazyGenDefault,
3725 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003726 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003727 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003728 "error_if_validators": (
3729 TosaErrorValidator.evRankMismatch,
3730 TosaErrorValidator.evWrongInputType,
3731 TosaErrorValidator.evWrongOutputType,
3732 TosaErrorValidator.evWrongInputList,
3733 TosaErrorValidator.evWrongOutputList,
3734 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003735 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003736 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003737 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003738 "bitwise_or": {
3739 "op": Op.BITWISE_OR,
3740 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003741 "build_fcn": (
3742 build_binary_broadcast,
3743 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003744 TosaTensorValuesGen.tvgLazyGenDefault,
3745 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003746 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003747 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003748 "error_if_validators": (
3749 TosaErrorValidator.evRankMismatch,
3750 TosaErrorValidator.evWrongInputType,
3751 TosaErrorValidator.evWrongOutputType,
3752 TosaErrorValidator.evWrongInputList,
3753 TosaErrorValidator.evWrongOutputList,
3754 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003755 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003756 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003757 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003758 "bitwise_xor": {
3759 "op": Op.BITWISE_XOR,
3760 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003761 "build_fcn": (
3762 build_binary_broadcast,
3763 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003764 TosaTensorValuesGen.tvgLazyGenDefault,
3765 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003766 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003767 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003768 "error_if_validators": (
3769 TosaErrorValidator.evRankMismatch,
3770 TosaErrorValidator.evWrongInputType,
3771 TosaErrorValidator.evWrongOutputType,
3772 TosaErrorValidator.evWrongInputList,
3773 TosaErrorValidator.evWrongOutputList,
3774 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003775 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003776 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003777 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003778 "intdiv": {
3779 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003780 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003781 "build_fcn": (
3782 build_binary_broadcast,
3783 TosaTensorGen.tgBroadcastFuzz,
3784 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003785 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003786 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003787 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003788 "error_if_validators": (
3789 TosaErrorValidator.evRankMismatch,
3790 TosaErrorValidator.evWrongInputType,
3791 TosaErrorValidator.evWrongOutputType,
3792 TosaErrorValidator.evWrongInputList,
3793 TosaErrorValidator.evWrongOutputList,
3794 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003795 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003796 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003797 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003798 "logical_and": {
3799 "op": Op.LOGICAL_AND,
3800 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003801 "build_fcn": (
3802 build_binary_broadcast,
3803 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003804 TosaTensorValuesGen.tvgLazyGenDefault,
3805 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003806 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003807 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003808 "error_if_validators": (
3809 TosaErrorValidator.evRankMismatch,
3810 TosaErrorValidator.evWrongInputType,
3811 TosaErrorValidator.evWrongOutputType,
3812 TosaErrorValidator.evWrongInputList,
3813 TosaErrorValidator.evWrongOutputList,
3814 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003815 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003816 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003817 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003818 "logical_left_shift": {
3819 "op": Op.LOGICAL_LEFT_SHIFT,
3820 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003821 "build_fcn": (
3822 build_binary_broadcast,
3823 TosaTensorGen.tgBroadcastFuzz,
3824 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003825 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003826 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003827 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 "error_if_validators": (
3829 TosaErrorValidator.evRankMismatch,
3830 TosaErrorValidator.evWrongInputType,
3831 TosaErrorValidator.evWrongOutputType,
3832 TosaErrorValidator.evWrongInputList,
3833 TosaErrorValidator.evWrongOutputList,
3834 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003835 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003836 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003837 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003838 "logical_right_shift": {
3839 "op": Op.LOGICAL_RIGHT_SHIFT,
3840 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003841 "build_fcn": (
3842 build_binary_broadcast,
3843 TosaTensorGen.tgBroadcastFuzz,
3844 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003845 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003846 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003847 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 "error_if_validators": (
3849 TosaErrorValidator.evRankMismatch,
3850 TosaErrorValidator.evWrongInputType,
3851 TosaErrorValidator.evWrongOutputType,
3852 TosaErrorValidator.evWrongInputList,
3853 TosaErrorValidator.evWrongOutputList,
3854 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003855 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003856 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003857 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003858 "logical_or": {
3859 "op": Op.LOGICAL_OR,
3860 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003861 "build_fcn": (
3862 build_binary_broadcast,
3863 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003864 TosaTensorValuesGen.tvgLazyGenDefault,
3865 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003866 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003867 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003868 "error_if_validators": (
3869 TosaErrorValidator.evRankMismatch,
3870 TosaErrorValidator.evWrongInputType,
3871 TosaErrorValidator.evWrongOutputType,
3872 TosaErrorValidator.evWrongInputList,
3873 TosaErrorValidator.evWrongOutputList,
3874 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003875 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003876 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003877 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003878 "logical_xor": {
3879 "op": Op.LOGICAL_XOR,
3880 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003881 "build_fcn": (
3882 build_binary_broadcast,
3883 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003884 TosaTensorValuesGen.tvgLazyGenDefault,
3885 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003886 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003887 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003888 "error_if_validators": (
3889 TosaErrorValidator.evRankMismatch,
3890 TosaErrorValidator.evWrongInputType,
3891 TosaErrorValidator.evWrongOutputType,
3892 TosaErrorValidator.evWrongInputList,
3893 TosaErrorValidator.evWrongOutputList,
3894 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003895 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003896 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003897 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003898 "maximum": {
3899 "op": Op.MAXIMUM,
3900 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003901 "build_fcn": (
3902 build_binary_broadcast,
3903 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003904 TosaTensorValuesGen.tvgLazyGenDefault,
3905 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003906 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003908 "error_if_validators": (
3909 TosaErrorValidator.evRankMismatch,
3910 TosaErrorValidator.evWrongInputType,
3911 TosaErrorValidator.evWrongOutputType,
3912 TosaErrorValidator.evWrongInputList,
3913 TosaErrorValidator.evWrongOutputList,
3914 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003915 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003916 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003917 "data_gen": {
3918 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3919 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "minimum": {
3922 "op": Op.MINIMUM,
3923 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_binary_broadcast,
3926 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003927 TosaTensorValuesGen.tvgLazyGenDefault,
3928 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evRankMismatch,
3933 TosaErrorValidator.evWrongInputType,
3934 TosaErrorValidator.evWrongOutputType,
3935 TosaErrorValidator.evWrongInputList,
3936 TosaErrorValidator.evWrongOutputList,
3937 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003938 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003940 "data_gen": {
3941 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3942 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003944 "mul": {
3945 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003946 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003947 "build_fcn": (
3948 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003949 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 TosaTensorValuesGen.tvgMul,
3951 TosaArgGen.agMul,
3952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003954 "error_if_validators": (
3955 TosaErrorValidator.evWrongInputType,
3956 TosaErrorValidator.evWrongOutputType,
3957 TosaErrorValidator.evWrongInputList,
3958 TosaErrorValidator.evWrongOutputList,
3959 TosaErrorValidator.evRankMismatch,
3960 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003961 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003963 "data_gen": {
3964 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3965 },
3966 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003968 "pow": {
3969 "op": Op.POW,
3970 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003971 "build_fcn": (
3972 build_binary_broadcast,
3973 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003974 TosaTensorValuesGen.tvgPow,
3975 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003976 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003977 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003978 "error_if_validators": (
3979 TosaErrorValidator.evRankMismatch,
3980 TosaErrorValidator.evWrongInputType,
3981 TosaErrorValidator.evWrongOutputType,
3982 TosaErrorValidator.evWrongInputList,
3983 TosaErrorValidator.evWrongOutputList,
3984 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003985 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003986 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003987 "data_gen": {
3988 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3989 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "sub": {
3992 "op": Op.SUB,
3993 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 "build_fcn": (
3995 build_binary_broadcast,
3996 TosaTensorGen.tgBroadcastFuzz,
3997 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003998 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evRankMismatch,
4003 TosaErrorValidator.evWrongInputType,
4004 TosaErrorValidator.evWrongOutputType,
4005 TosaErrorValidator.evWrongInputList,
4006 TosaErrorValidator.evWrongOutputList,
4007 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004008 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004010 "data_gen": {
4011 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4012 },
4013 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004015 "table": {
4016 "op": Op.TABLE,
4017 # Use the automatic generation functions to create the input array
4018 # but create the table tensor in the build function, as it may be
4019 # a different type from the input
4020 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004021 "build_fcn": (
4022 build_table,
4023 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004024 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004025 TosaArgGen.agTable,
4026 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004027 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004028 "error_if_validators": (
4029 TosaErrorValidator.evWrongInputType,
4030 TosaErrorValidator.evWrongOutputType,
4031 TosaErrorValidator.evWrongInputList,
4032 TosaErrorValidator.evWrongOutputList,
4033 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004034 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004035 # Elementwise Unary operators
4036 "abs": {
4037 "op": Op.ABS,
4038 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004039 "build_fcn": (
4040 build_unary,
4041 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004042 TosaTensorValuesGen.tvgLazyGenDefault,
4043 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004044 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004045 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004046 "error_if_validators": (
4047 TosaErrorValidator.evWrongInputType,
4048 TosaErrorValidator.evWrongOutputType,
4049 TosaErrorValidator.evWrongInputList,
4050 TosaErrorValidator.evWrongOutputList,
4051 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004052 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004053 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004054 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 "bitwise_not": {
4057 "op": Op.BITWISE_NOT,
4058 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004059 "build_fcn": (
4060 build_unary,
4061 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004062 TosaTensorValuesGen.tvgLazyGenDefault,
4063 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004064 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004065 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004066 "error_if_validators": (
4067 TosaErrorValidator.evWrongInputType,
4068 TosaErrorValidator.evWrongOutputType,
4069 TosaErrorValidator.evWrongInputList,
4070 TosaErrorValidator.evWrongOutputList,
4071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004073 "ceil": {
4074 "op": Op.CEIL,
4075 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004076 "build_fcn": (
4077 build_unary,
4078 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004079 TosaTensorValuesGen.tvgLazyGenDefault,
4080 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004081 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004083 "error_if_validators": (
4084 TosaErrorValidator.evWrongInputType,
4085 TosaErrorValidator.evWrongOutputType,
4086 TosaErrorValidator.evWrongInputList,
4087 TosaErrorValidator.evWrongOutputList,
4088 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004089 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004090 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004091 },
4092 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004094 "clz": {
4095 "op": Op.CLZ,
4096 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004097 "build_fcn": (
4098 build_unary,
4099 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004100 TosaTensorValuesGen.tvgLazyGenDefault,
4101 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004102 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004103 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004104 "error_if_validators": (
4105 TosaErrorValidator.evWrongInputType,
4106 TosaErrorValidator.evWrongOutputType,
4107 TosaErrorValidator.evWrongInputList,
4108 TosaErrorValidator.evWrongOutputList,
4109 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004110 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004111 "cos": {
4112 "op": Op.COS,
4113 "operands": (1, 0),
4114 "build_fcn": (
4115 build_unary,
4116 TosaTensorGen.tgBasic,
4117 TosaTensorValuesGen.tvgLazyGenDefault,
4118 TosaArgGen.agNone,
4119 ),
4120 "types": TYPE_FP,
4121 "error_if_validators": (
4122 TosaErrorValidator.evWrongInputType,
4123 TosaErrorValidator.evWrongOutputType,
4124 TosaErrorValidator.evWrongInputList,
4125 TosaErrorValidator.evWrongOutputList,
4126 ),
4127 "data_gen": {
4128 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4129 },
4130 "compliance": {"abs_error_normal_divisor": 2},
4131 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004132 "exp": {
4133 "op": Op.EXP,
4134 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004135 "build_fcn": (
4136 build_unary,
4137 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004138 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004139 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004140 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004141 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004142 "error_if_validators": (
4143 TosaErrorValidator.evWrongInputType,
4144 TosaErrorValidator.evWrongOutputType,
4145 TosaErrorValidator.evWrongInputList,
4146 TosaErrorValidator.evWrongOutputList,
4147 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004148 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004149 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004150 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004151 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004152 "floor": {
4153 "op": Op.FLOOR,
4154 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004155 "build_fcn": (
4156 build_unary,
4157 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004158 TosaTensorValuesGen.tvgLazyGenDefault,
4159 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004160 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004161 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004162 "error_if_validators": (
4163 TosaErrorValidator.evWrongInputType,
4164 TosaErrorValidator.evWrongOutputType,
4165 TosaErrorValidator.evWrongInputList,
4166 TosaErrorValidator.evWrongOutputList,
4167 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004168 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004169 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004170 },
4171 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004172 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004173 "log": {
4174 "op": Op.LOG,
4175 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004176 "build_fcn": (
4177 build_unary,
4178 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004179 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004180 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004182 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004183 "error_if_validators": (
4184 TosaErrorValidator.evWrongInputType,
4185 TosaErrorValidator.evWrongOutputType,
4186 TosaErrorValidator.evWrongInputList,
4187 TosaErrorValidator.evWrongOutputList,
4188 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004189 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004190 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004191 },
4192 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004193 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004194 "logical_not": {
4195 "op": Op.LOGICAL_NOT,
4196 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004197 "build_fcn": (
4198 build_unary,
4199 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004200 TosaTensorValuesGen.tvgLazyGenDefault,
4201 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004202 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004203 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004204 "error_if_validators": (
4205 TosaErrorValidator.evWrongInputType,
4206 TosaErrorValidator.evWrongOutputType,
4207 TosaErrorValidator.evWrongInputList,
4208 TosaErrorValidator.evWrongOutputList,
4209 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004211 "negate": {
4212 "op": Op.NEGATE,
4213 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004214 "build_fcn": (
4215 build_unary,
4216 TosaTensorGen.tgBasic,
4217 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004218 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004219 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004220 "qgen": TosaQuantGen.qgUnary,
4221 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004222 "error_if_validators": (
4223 TosaErrorValidator.evInputZeroPointNotZero,
4224 TosaErrorValidator.evOutputZeroPointNotZero,
4225 TosaErrorValidator.evWrongInputType,
4226 TosaErrorValidator.evWrongOutputType,
4227 TosaErrorValidator.evWrongInputList,
4228 TosaErrorValidator.evWrongOutputList,
4229 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004230 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004231 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004232 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004234 "reciprocal": {
4235 "op": Op.RECIPROCAL,
4236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 "build_fcn": (
4238 build_unary,
4239 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004240 TosaTensorValuesGen.tvgLazyGenDefault,
4241 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004243 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004244 "error_if_validators": (
4245 TosaErrorValidator.evWrongInputType,
4246 TosaErrorValidator.evWrongOutputType,
4247 TosaErrorValidator.evWrongInputList,
4248 TosaErrorValidator.evWrongOutputList,
4249 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004250 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004251 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004252 },
4253 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004255 "rsqrt": {
4256 "op": Op.RSQRT,
4257 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004258 "build_fcn": (
4259 build_unary,
4260 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004261 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004262 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004263 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004264 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004265 "error_if_validators": (
4266 TosaErrorValidator.evWrongInputType,
4267 TosaErrorValidator.evWrongOutputType,
4268 TosaErrorValidator.evWrongInputList,
4269 TosaErrorValidator.evWrongOutputList,
4270 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004271 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004272 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004273 },
4274 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004275 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004276 "sin": {
4277 "op": Op.SIN,
4278 "operands": (1, 0),
4279 "build_fcn": (
4280 build_unary,
4281 TosaTensorGen.tgBasic,
4282 TosaTensorValuesGen.tvgLazyGenDefault,
4283 TosaArgGen.agNone,
4284 ),
4285 "types": TYPE_FP,
4286 "error_if_validators": (
4287 TosaErrorValidator.evWrongInputType,
4288 TosaErrorValidator.evWrongOutputType,
4289 TosaErrorValidator.evWrongInputList,
4290 TosaErrorValidator.evWrongOutputList,
4291 ),
4292 "data_gen": {
4293 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4294 },
4295 "compliance": {"abs_error_normal_divisor": 2},
4296 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004297 # Elementwise Ternary operators
4298 "select": {
4299 "op": Op.SELECT,
4300 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004301 "build_fcn": (
4302 build_select,
4303 TosaTensorGen.tgBroadcastFuzz,
4304 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004305 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004306 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004307 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004308 "error_if_validators": (
4309 TosaErrorValidator.evRankMismatch,
4310 TosaErrorValidator.evWrongInputType,
4311 TosaErrorValidator.evWrongOutputType,
4312 TosaErrorValidator.evWrongInputList,
4313 TosaErrorValidator.evWrongOutputList,
4314 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004315 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004316 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004317 "data_gen": {
4318 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004320 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004321 # Comparison operators
4322 "equal": {
4323 "op": Op.EQUAL,
4324 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004325 "build_fcn": (
4326 build_comparison,
4327 TosaTensorGen.tgBroadcastFuzz,
4328 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004329 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004330 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004331 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004332 "error_if_validators": (
4333 TosaErrorValidator.evRankMismatch,
4334 TosaErrorValidator.evWrongInputType,
4335 TosaErrorValidator.evWrongOutputType,
4336 TosaErrorValidator.evWrongInputList,
4337 TosaErrorValidator.evWrongOutputList,
4338 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004339 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004340 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004341 "data_gen": {
4342 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004344 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004345 "greater_equal": {
4346 "op": Op.GREATER_EQUAL,
4347 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004348 "build_fcn": (
4349 build_comparison,
4350 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004351 TosaTensorValuesGen.tvgLazyGenDefault,
4352 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004354 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004355 "error_if_validators": (
4356 TosaErrorValidator.evRankMismatch,
4357 TosaErrorValidator.evWrongInputType,
4358 TosaErrorValidator.evWrongOutputType,
4359 TosaErrorValidator.evWrongInputList,
4360 TosaErrorValidator.evWrongOutputList,
4361 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004362 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004363 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004364 "data_gen": {
4365 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004368 "greater": {
4369 "op": Op.GREATER,
4370 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004371 "build_fcn": (
4372 build_comparison,
4373 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004374 TosaTensorValuesGen.tvgLazyGenDefault,
4375 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004377 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004378 "error_if_validators": (
4379 TosaErrorValidator.evRankMismatch,
4380 TosaErrorValidator.evWrongInputType,
4381 TosaErrorValidator.evWrongOutputType,
4382 TosaErrorValidator.evWrongInputList,
4383 TosaErrorValidator.evWrongOutputList,
4384 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004385 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004386 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004387 "data_gen": {
4388 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4389 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 # Reduction operators
4392 "reduce_all": {
4393 "op": Op.REDUCE_ALL,
4394 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004395 "build_fcn": (
4396 build_reduce,
4397 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004398 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004399 TosaArgGen.agAxis,
4400 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004401 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004402 "error_if_validators": (
4403 TosaErrorValidator.evAxisLargerRank,
4404 TosaErrorValidator.evAxisSmallerZero,
4405 TosaErrorValidator.evShapeOfAxisNotOne,
4406 TosaErrorValidator.evWrongInputType,
4407 TosaErrorValidator.evWrongOutputType,
4408 TosaErrorValidator.evWrongRank,
4409 TosaErrorValidator.evWrongInputList,
4410 TosaErrorValidator.evWrongOutputList,
4411 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004413 "reduce_any": {
4414 "op": Op.REDUCE_ANY,
4415 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004416 "build_fcn": (
4417 build_reduce,
4418 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004419 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004420 TosaArgGen.agAxis,
4421 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004422 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004423 "error_if_validators": (
4424 TosaErrorValidator.evAxisLargerRank,
4425 TosaErrorValidator.evAxisSmallerZero,
4426 TosaErrorValidator.evShapeOfAxisNotOne,
4427 TosaErrorValidator.evWrongInputType,
4428 TosaErrorValidator.evWrongOutputType,
4429 TosaErrorValidator.evWrongRank,
4430 TosaErrorValidator.evWrongInputList,
4431 TosaErrorValidator.evWrongOutputList,
4432 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004433 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004434 "reduce_max": {
4435 "op": Op.REDUCE_MAX,
4436 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 "build_fcn": (
4438 build_reduce,
4439 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004440 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004441 TosaArgGen.agAxis,
4442 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004443 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004444 "error_if_validators": (
4445 TosaErrorValidator.evAxisLargerRank,
4446 TosaErrorValidator.evAxisSmallerZero,
4447 TosaErrorValidator.evShapeOfAxisNotOne,
4448 TosaErrorValidator.evWrongInputType,
4449 TosaErrorValidator.evWrongOutputType,
4450 TosaErrorValidator.evWrongRank,
4451 TosaErrorValidator.evWrongInputList,
4452 TosaErrorValidator.evWrongOutputList,
4453 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004454 "data_gen": {
4455 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004457 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004458 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004459 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004460 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004461 "build_fcn": (
4462 build_reduce,
4463 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004464 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004465 TosaArgGen.agAxis,
4466 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004467 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004468 "error_if_validators": (
4469 TosaErrorValidator.evAxisLargerRank,
4470 TosaErrorValidator.evAxisSmallerZero,
4471 TosaErrorValidator.evShapeOfAxisNotOne,
4472 TosaErrorValidator.evWrongInputType,
4473 TosaErrorValidator.evWrongOutputType,
4474 TosaErrorValidator.evWrongRank,
4475 TosaErrorValidator.evWrongInputList,
4476 TosaErrorValidator.evWrongOutputList,
4477 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004478 "data_gen": {
4479 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4480 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004481 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004482 "reduce_product": {
4483 "op": Op.REDUCE_PRODUCT,
4484 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004485 "build_fcn": (
4486 build_reduce,
4487 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004488 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004489 TosaArgGen.agAxis,
4490 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004491 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004492 "error_if_validators": (
4493 TosaErrorValidator.evAxisLargerRank,
4494 TosaErrorValidator.evAxisSmallerZero,
4495 TosaErrorValidator.evShapeOfAxisNotOne,
4496 TosaErrorValidator.evWrongInputType,
4497 TosaErrorValidator.evWrongOutputType,
4498 TosaErrorValidator.evWrongRank,
4499 TosaErrorValidator.evWrongInputList,
4500 TosaErrorValidator.evWrongOutputList,
4501 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004502 "data_gen": {
4503 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004505 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004506 "reduce_sum": {
4507 "op": Op.REDUCE_SUM,
4508 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004509 "build_fcn": (
4510 build_reduce,
4511 TosaTensorGen.tgBasic,
4512 TosaTensorValuesGen.tvgReduceSum,
4513 TosaArgGen.agAxis,
4514 ),
James Ward24dbc422022-10-19 12:20:31 +01004515 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004516 "error_if_validators": (
4517 TosaErrorValidator.evAxisLargerRank,
4518 TosaErrorValidator.evAxisSmallerZero,
4519 TosaErrorValidator.evShapeOfAxisNotOne,
4520 TosaErrorValidator.evWrongInputType,
4521 TosaErrorValidator.evWrongOutputType,
4522 TosaErrorValidator.evWrongRank,
4523 TosaErrorValidator.evWrongInputList,
4524 TosaErrorValidator.evWrongOutputList,
4525 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004526 "data_gen": {
4527 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004529 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004530 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004531 "concat": {
4532 "op": Op.CONCAT,
4533 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004534 "build_fcn": (
4535 build_concat,
4536 TosaTensorGen.tgConcat,
4537 TosaTensorValuesGen.tvgConcat,
4538 TosaArgGen.agAxis,
4539 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004540 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004541 "error_if_validators": (
4542 TosaErrorValidator.evAxisLargerRank,
4543 TosaErrorValidator.evAxisSmallerZero,
4544 TosaErrorValidator.evConcatInputRankMismatch,
4545 TosaErrorValidator.evConcatShapeSumMismatch,
4546 TosaErrorValidator.evConcatInputDimMismatch,
4547 TosaErrorValidator.evWrongInputType,
4548 TosaErrorValidator.evWrongOutputType,
4549 TosaErrorValidator.evWrongOutputList,
4550 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004551 "data_gen": {
4552 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4553 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004554 },
4555 "pad": {
4556 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004557 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004558 "build_fcn": (
4559 build_pad,
4560 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004561 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004562 TosaArgGen.agPad,
4563 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004564 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004565 "error_if_validators": (
4566 TosaErrorValidator.evWrongInputType,
4567 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004568 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004569 TosaErrorValidator.evWrongOutputType,
4570 TosaErrorValidator.evWrongInputList,
4571 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004572 TosaErrorValidator.evRankMismatch,
4573 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004574 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004575 "data_gen": {
4576 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4577 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004578 },
Won Jeona21b2e82023-08-10 10:33:01 +00004579 "dim": {
4580 "op": Op.DIM,
4581 "operands": (1, 0),
4582 "build_fcn": (
4583 build_dim,
4584 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004585 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004586 TosaArgGen.agAxis,
4587 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004588 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004589 "error_if_validators": (
4590 TosaErrorValidator.evAxisLargerRank,
4591 TosaErrorValidator.evAxisSmallerZero,
4592 TosaErrorValidator.evWrongInputType,
4593 TosaErrorValidator.evWrongInputList,
4594 TosaErrorValidator.evWrongOutputList,
4595 TosaErrorValidator.evWrongRank,
4596 ),
4597 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004598 "reshape": {
4599 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004600 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004601 "build_fcn": (
4602 build_reshape,
4603 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004604 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004605 TosaArgGen.agReshape,
4606 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004607 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004608 "error_if_validators": (
4609 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4610 TosaErrorValidator.evWrongInputType,
4611 TosaErrorValidator.evWrongOutputType,
4612 TosaErrorValidator.evWrongInputList,
4613 TosaErrorValidator.evWrongOutputList,
4614 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004615 "data_gen": {
4616 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4617 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004618 },
4619 "reverse": {
4620 "op": Op.REVERSE,
4621 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004622 "build_fcn": (
4623 build_reverse,
4624 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004625 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004626 TosaArgGen.agAxis,
4627 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004628 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004629 "error_if_validators": (
4630 TosaErrorValidator.evAxisSmallerZero,
4631 TosaErrorValidator.evAxisLargerRank,
4632 TosaErrorValidator.evWrongInputType,
4633 TosaErrorValidator.evWrongOutputType,
4634 TosaErrorValidator.evWrongInputList,
4635 TosaErrorValidator.evWrongOutputList,
4636 ),
evacha0198477222024-01-26 12:25:32 +00004637 "data_gen": {
4638 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4639 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004640 },
4641 "slice": {
4642 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004643 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004644 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004645 "build_fcn": (
4646 build_slice,
4647 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004648 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004649 TosaArgGen.agSlice,
4650 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004651 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004652 "error_if_validators": (
                # TODO Turn off these error categories for now as the reference
                # model cannot allocate memory space for an empty tensor. We can
                # probably report an accurate error message at the right place
                # during execution.
4657 # TosaErrorValidator.evStartSmallerZero,
4658 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004659 TosaErrorValidator.evStartSizeOutsideBounds,
4660 TosaErrorValidator.evSizeOutputShapeMismatch,
4661 TosaErrorValidator.evInputSizeStartLengthMismatch,
4662 TosaErrorValidator.evWrongRank,
4663 TosaErrorValidator.evWrongInputType,
4664 TosaErrorValidator.evWrongOutputType,
4665 TosaErrorValidator.evWrongInputList,
4666 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004667 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004668 ),
evacha017f7d4252024-01-24 12:08:09 +00004669 "data_gen": {
4670 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4671 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004672 },
4673 "tile": {
4674 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004675 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004676 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004677 "build_fcn": (
4678 build_tile,
4679 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004680 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004681 TosaArgGen.agTile,
4682 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004683 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004684 "error_if_validators": (
4685 TosaErrorValidator.evWrongInputType,
4686 TosaErrorValidator.evWrongOutputType,
4687 TosaErrorValidator.evWrongInputList,
4688 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004689 TosaErrorValidator.evRankMismatch,
4690 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004691 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004692 "data_gen": {
4693 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4694 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004695 },
4696 "transpose": {
4697 "op": Op.TRANSPOSE,
4698 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004699 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004700 "build_fcn": (
4701 build_transpose,
4702 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004703 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004704 TosaArgGen.agTranspose,
4705 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004706 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004707 "error_if_validators": (
4708 TosaErrorValidator.evIndexOutsideBounds,
4709 TosaErrorValidator.evIndexUsedTwice,
4710 TosaErrorValidator.evWrongInputType,
4711 TosaErrorValidator.evWrongOutputType,
4712 TosaErrorValidator.evWrongInputList,
4713 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004714 TosaErrorValidator.evWrongRank,
4715 TosaErrorValidator.evRankMismatch,
4716 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004717 ),
evacha0198477222024-01-26 12:25:32 +00004718 "data_gen": {
4719 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4720 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004721 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004722 # Data nodes
4723 "const": {
4724 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004725 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004726 "build_fcn": (
4727 build_const,
4728 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004729 TosaTensorValuesGen.tvgLazyGenDefault,
4730 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004731 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004732 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004733 "data_gen": {
4734 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4735 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004736 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004737 "identity": {
4738 "op": Op.IDENTITY,
4739 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004740 "build_fcn": (
4741 build_unary,
4742 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004743 TosaTensorValuesGen.tvgLazyGenDefault,
4744 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004745 ),
evacha011adff832024-03-06 17:33:44 +00004746 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004747 "data_gen": {
4748 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4749 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004750 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004751 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004752 "gather": {
4753 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004754 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004755 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004756 "build_fcn": (
4757 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004758 TosaTensorGen.tgGather,
4759 TosaTensorValuesGen.tvgGather,
4760 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004761 ),
James Ward24dbc422022-10-19 12:20:31 +01004762 "types": (
4763 DType.INT8,
4764 DType.INT16,
4765 DType.INT32,
4766 DType.FP16,
4767 DType.BF16,
4768 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004769 DType.FP8E4M3,
4770 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004771 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004772 "error_if_validators": (
4773 TosaErrorValidator.evWrongInputType,
4774 TosaErrorValidator.evWrongOutputType,
4775 TosaErrorValidator.evWrongInputList,
4776 TosaErrorValidator.evWrongOutputList,
4777 TosaErrorValidator.evWrongRank,
4778 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004779 "data_gen": {
4780 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4781 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004782 },
4783 "scatter": {
4784 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004785 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004786 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004787 "build_fcn": (
4788 build_scatter,
4789 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004790 TosaTensorValuesGen.tvgScatter,
4791 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004792 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004793 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004794 "error_if_validators": (
4795 TosaErrorValidator.evWrongInputType,
4796 TosaErrorValidator.evWrongOutputType,
4797 TosaErrorValidator.evWrongInputList,
4798 TosaErrorValidator.evWrongOutputList,
4799 TosaErrorValidator.evWrongRank,
4800 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004801 "data_gen": {
4802 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4803 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004804 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004805 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004806 "resize": {
4807 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004808 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004809 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004810 "build_fcn": (
4811 build_resize,
4812 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004813 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004814 TosaArgGen.agResize,
4815 ),
James Ward24dbc422022-10-19 12:20:31 +01004816 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004817 "invalid_test_validators": (
4818 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004819 ),
4820 "error_if_validators": (
4821 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004822 TosaErrorValidator.evScaleSmallerEqualZero,
4823 TosaErrorValidator.evScaleNLargerMax,
4824 TosaErrorValidator.evScaleDLargerMax,
4825 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004826 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004827 TosaErrorValidator.evBorderSmallerMin,
4828 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004829 TosaErrorValidator.evWrongInputType,
4830 TosaErrorValidator.evWrongOutputType,
4831 TosaErrorValidator.evWrongRank,
4832 TosaErrorValidator.evWrongInputList,
4833 TosaErrorValidator.evWrongOutputList,
4834 TosaErrorValidator.evBatchMismatch,
4835 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004836 TosaErrorValidator.evResizeOutputShapeMismatch,
4837 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004838 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004839 "data_gen": {
4840 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4841 },
4842 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004843 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004844 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004845 "cast": {
4846 "op": Op.CAST,
4847 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004848 "build_fcn": (
4849 build_cast,
4850 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004851 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004852 TosaArgGen.agCast,
4853 ),
James Ward8b390432022-08-12 20:48:56 +01004854 "types": (
4855 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004856 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004857 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004858 DType.INT8,
4859 DType.INT16,
4860 DType.INT32,
4861 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004862 DType.FP8E4M3,
4863 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004864 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004865 "error_if_validators": (
4866 TosaErrorValidator.evWrongInputType,
4867 TosaErrorValidator.evWrongOutputType,
4868 TosaErrorValidator.evWrongInputList,
4869 TosaErrorValidator.evWrongOutputList,
4870 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004871 "data_gen": {
4872 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4873 },
4874 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004875 },
4876 "rescale": {
4877 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004878 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004879 "build_fcn": (
4880 build_rescale,
4881 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004882 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004883 TosaArgGen.agRescale,
4884 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004885 "types": [
4886 DType.UINT8,
4887 DType.INT8,
4888 DType.INT16,
4889 DType.INT32,
4890 DType.INT48,
4891 DType.UINT16,
4892 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004893 "error_if_validators": (
4894 TosaErrorValidator.evInputZeroPointNotZero,
4895 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004896 TosaErrorValidator.evU16InputZeroPointNotValid,
4897 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004898 TosaErrorValidator.evScaleTrue,
4899 TosaErrorValidator.evScaleNotTrue,
4900 TosaErrorValidator.evWrongInputType,
4901 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004902 TosaErrorValidator.evWrongInputList,
4903 TosaErrorValidator.evWrongOutputList,
4904 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004905 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004906 # Custom
4907 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004908 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant
        # tensors (no inputs to the basic blocks, one output) and another that
        # either adds or subtracts two tensors (two inputs to the basic blocks,
        # one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004912 "cond_if_const": {
4913 "op": Op.COND_IF,
4914 "operands": (0, 2),
4915 "build_fcn": (
4916 build_cond_if_const,
4917 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004918 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004919 TosaArgGen.agCondIf,
4920 ),
4921 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004922 "error_if_validators": (
4923 TosaErrorValidator.evOutputListThenGraphMismatch,
4924 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004925 TosaErrorValidator.evCondIfCondNotMatchingBool,
4926 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004927 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004928 },
4929 "cond_if_binary": {
4930 "op": Op.COND_IF,
4931 "operands": (2, 0),
4932 "build_fcn": (
4933 build_cond_if_binary,
4934 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004935 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004936 TosaArgGen.agCondIf,
4937 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004938 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004939 "error_if_validators": (
4940 TosaErrorValidator.evInputListThenGraphMismatch,
4941 TosaErrorValidator.evInputListElseGraphMismatch,
4942 TosaErrorValidator.evOutputListThenGraphMismatch,
4943 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004944 TosaErrorValidator.evCondIfCondNotMatchingBool,
4945 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004946 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004947 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004948 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004949 "while_loop": {
4950 "op": Op.WHILE_LOOP,
4951 "operands": (0, 1),
4952 "build_fcn": (
4953 build_while_loop,
4954 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004955 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004956 TosaArgGen.agWhileLoop,
4957 ),
4958 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004959 "error_if_validators": (
4960 TosaErrorValidator.evInputListOutputListMismatch,
4961 TosaErrorValidator.evInputListCondGraphMismatch,
4962 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4963 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4964 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004965 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004966 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004967 },
Luke Hutton57287132023-02-06 14:54:18 +00004968 "fft2d": {
4969 "op": Op.FFT2D,
4970 "operands": (2, 0),
4971 "rank": (3, 3),
4972 "build_fcn": (
4973 build_fft2d,
4974 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004975 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004976 TosaArgGen.agFFT2d,
4977 ),
4978 "types": [DType.FP32],
4979 "error_if_validators": (
4980 TosaErrorValidator.evWrongInputType,
4981 TosaErrorValidator.evWrongOutputType,
4982 TosaErrorValidator.evWrongInputList,
4983 TosaErrorValidator.evWrongOutputList,
4984 TosaErrorValidator.evWrongRank,
4985 TosaErrorValidator.evBatchMismatch,
4986 TosaErrorValidator.evKernelNotPowerOfTwo,
4987 TosaErrorValidator.evFFTInputShapeMismatch,
4988 TosaErrorValidator.evFFTOutputShapeMismatch,
4989 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004990 "data_gen": {
4991 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4992 },
Luke Hutton57287132023-02-06 14:54:18 +00004993 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004994 "rfft2d": {
4995 "op": Op.RFFT2D,
4996 "operands": (1, 0),
4997 "rank": (3, 3),
4998 "build_fcn": (
4999 build_rfft2d,
5000 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005001 TosaTensorValuesGen.tvgLazyGenDefault,
5002 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00005003 ),
5004 "types": [DType.FP32],
5005 "error_if_validators": (
5006 TosaErrorValidator.evWrongInputType,
5007 TosaErrorValidator.evWrongOutputType,
5008 TosaErrorValidator.evWrongInputList,
5009 TosaErrorValidator.evWrongOutputList,
5010 TosaErrorValidator.evWrongRank,
5011 TosaErrorValidator.evBatchMismatch,
5012 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00005013 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00005014 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005015 "data_gen": {
5016 "fp": (gtu.DataGenType.DOT_PRODUCT,),
5017 },
Luke Hutton261b7b62023-01-10 14:50:31 +00005018 },
Won Jeon74342e52024-01-09 00:34:40 +00005019 # Shape
5020 "add_shape": {
5021 "op": Op.ADD_SHAPE,
5022 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005023 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005024 "build_fcn": (
5025 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005026 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005027 TosaTensorValuesGen.tvgAddSub,
5028 TosaArgGen.agNone,
5029 ),
5030 "types": [DType.SHAPE],
5031 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5032 },
5033 "sub_shape": {
5034 "op": Op.SUB_SHAPE,
5035 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005036 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005037 "build_fcn": (
5038 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005039 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005040 TosaTensorValuesGen.tvgAddSub,
5041 TosaArgGen.agNone,
5042 ),
5043 "types": [DType.SHAPE],
5044 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5045 },
5046 "mul_shape": {
5047 "op": Op.MUL_SHAPE,
5048 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005049 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005050 "build_fcn": (
5051 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005052 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005053 TosaTensorValuesGen.tvgMul,
5054 TosaArgGen.agNone,
5055 ),
5056 "types": [DType.SHAPE],
5057 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5058 },
5059 "div_shape": {
5060 "op": Op.DIV_SHAPE,
5061 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005062 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005063 "build_fcn": (
5064 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005065 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005066 TosaTensorValuesGen.tvgIntDiv,
5067 TosaArgGen.agNone,
5068 ),
5069 "types": [DType.SHAPE],
5070 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5071 },
5072 "concat_shape": {
5073 "op": Op.CONCAT_SHAPE,
5074 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005075 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005076 "build_fcn": (
5077 build_concat,
5078 TosaTensorGen.tgConcat,
5079 TosaTensorValuesGen.tvgConcat,
5080 TosaArgGen.agNone,
5081 ),
5082 "types": [DType.SHAPE],
5083 "error_if_validators": (),
5084 },
5085 "const_shape": {
5086 "op": Op.CONST_SHAPE,
5087 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005088 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005089 "build_fcn": (
5090 build_const,
5091 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005092 TosaTensorValuesGen.tvgLazyGenDefault,
5093 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005094 ),
5095 "types": [DType.SHAPE],
5096 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005097 }
5098
Kevin Cheng550ccc52021-03-03 11:21:43 -08005099
Eric Kunzee5e26762020-10-13 16:11:07 -07005100class OutputShaper:
5101 # Methods in this class compute the expected output shape and datatype
5102 # for common classes of operations
5103 def __init__(self):
5104 pass
5105
5106 # These methods return arguments that can be used for
5107 # creating a new output tensor
5108 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005109 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
5110 if error_name != ErrorIf.RankMismatch:
5111 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005112 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005113
5114 shape = []
5115 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005116 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07005117 shape.append(b.shape[i])
5118 else:
5119 shape.append(a.shape[i])
5120
Jerry Ge135c9552023-05-23 20:59:32 +00005121 fuzz_idx = rng.integers(0, len(a.shape))
5122 if error_name == ErrorIf.DimensionMismatch:
5123 shape[fuzz_idx] += 1
5124
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005125 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005126 all_dtypes = [
5127 DType.INT8,
5128 DType.INT16,
5129 DType.INT32,
5130 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01005131 DType.FP16,
5132 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005133 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005134 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005135 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5136 outputDType = rng.choice(wrong_dtypes)
5137 else:
5138 outputDType = a.dtype
5139
5140 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005141
5142 @staticmethod
5143 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005144 assert len(a.shape) == len(b.shape)
5145 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005146
5147 shape = []
5148 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005149 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005150 shape.append(a.shape[i])
5151
Kevin Cheng550ccc52021-03-03 11:21:43 -08005152 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005153
5154 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005155 def unaryOp(ser, rng, a, error_name=None):
5156 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005157 all_dtypes = [
5158 DType.INT8,
5159 DType.INT16,
5160 DType.INT32,
5161 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005162 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005163 DType.FP16,
5164 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005165 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005166 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5167 outputDType = rng.choice(wrong_dtypes)
5168 else:
5169 outputDType = a.dtype
5170
5171 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005172
5173 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005174 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005175 if error_name != ErrorIf.RankMismatch:
5176 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005177 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005178
5179 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005180 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005181 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005182 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
5183 else:
5184 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07005185
Jerry Ge135c9552023-05-23 20:59:32 +00005186 fuzz_idx = rng.integers(0, len(a.shape))
5187 if error_name == ErrorIf.DimensionMismatch:
5188 shape[fuzz_idx] += 1
5189
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005190 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005191 all_dtypes = [
5192 DType.INT8,
5193 DType.INT16,
5194 DType.INT32,
5195 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005196 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005197 DType.FP16,
5198 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005199 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005200 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5201 outputDType = rng.choice(wrong_dtypes)
5202 else:
5203 outputDType = a.dtype
5204
5205 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005206
5207 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005208 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005209 if error_name != ErrorIf.RankMismatch:
5210 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005211 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005212
5213 # Do broadcast
5214 shape = []
5215 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08005216 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07005217 shape.append(b.shape[i])
5218 else:
5219 shape.append(a.shape[i])
5220
Jerry Ge135c9552023-05-23 20:59:32 +00005221 fuzz_idx = rng.integers(0, len(a.shape))
5222 if error_name == ErrorIf.DimensionMismatch:
5223 shape[fuzz_idx] += 1
5224
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005225 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005226 wrong_dtypes = [
5227 DType.INT8,
5228 DType.INT16,
5229 DType.INT32,
5230 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005231 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005232 DType.FP16,
5233 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005234 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005235 outputDType = rng.choice(wrong_dtypes)
5236 else:
5237 outputDType = DType.BOOL
5238
5239 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005240
5241 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005242 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005243 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005244 if error_name not in [
5245 ErrorIf.AxisSmallerZero,
5246 ErrorIf.AxisLargerRank,
5247 ErrorIf.ShapeOfAxisNotOne,
5248 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01005249 shape[axis] = 1
5250 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5251 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005252
Matthew Haddond6ce7252021-09-29 15:35:44 +01005253 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005254 all_dtypes = [
5255 DType.INT8,
5256 DType.INT16,
5257 DType.INT32,
5258 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005259 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005260 DType.FP16,
5261 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005262 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01005263 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5264 outputDType = rng.choice(wrong_dtypes)
5265 else:
5266 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005267
Matthew Haddond6ce7252021-09-29 15:35:44 +01005268 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005269
5270 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005271 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005272 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005273
5274 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5275 del shape[axis]
5276
5277 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5278 remove = rng.choice([True, False])
5279 if remove and len(shape) > 1:
5280 del shape[0]
5281 else:
5282 shape.append(1)
5283 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5284 for i in range(len(shape)):
5285 shape[i] = shape[i] + rng.integers(1, 10)
5286
5287 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005288 all_dtypes = [
5289 DType.INT8,
5290 DType.INT16,
5291 DType.INT32,
5292 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005293 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005294 DType.FP16,
5295 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005296 DType.FP8E4M3,
5297 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005298 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005299 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5300 outputDType = rng.choice(wrong_dtypes)
5301 else:
5302 outputDType = DType.INT32
5303
5304 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005305
5306 @staticmethod
Tai Lyf36f2562024-03-14 16:21:29 +00005307 def _get_conv_output_type(input_dtype):
5308 if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
5309 return input_dtype
5310 elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
5311 return DType.FP16
5312 elif input_dtype in (DType.INT8, DType.INT4):
5313 return DType.INT32
5314 elif input_dtype in (DType.INT16,):
5315 return DType.INT48
5316 assert True, f"Unsupported convolution data type {input_dtype}"
5317
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the CONV2D output tensor to ``ser`` and return it.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding`` holds
    (top, bottom, left, right). ``accum_dtype`` is not used here. For
    ERROR_IF tests the output shape or dtype is corrupted according to
    ``error_name``.
    """

    def _conv_out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Standard dilated-convolution output size
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    height = _conv_out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    width = _conv_out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit instead
        if which in [1, 3]:
            height += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            width += rng.choice(steps) * strides[1]

    out_shape = [ifm.shape[0], height, width, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input is already invalid, so any plausible output type will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excluded = [DType.FP16]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005371
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the CONV3D output tensor to ``ser`` and return it.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. ``padding`` holds
    before/after pairs for D, H and W in that order. ``accum_dtype`` is not
    used here. For ERROR_IF tests the output shape or dtype is corrupted
    according to ``error_name``.
    """

    def _conv_out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Standard dilated-convolution output size
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    depth = _conv_out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    height = _conv_out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    width = _conv_out_size(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit instead
        if which in [1, 4]:
            depth += rng.choice(steps) * strides[0]
        if which in [2, 4]:
            height += rng.choice(steps) * strides[1]
        if which in [3, 4]:
            width += rng.choice(steps) * strides[2]

    out_shape = [ifm.shape[0], depth, height, width, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input is already invalid, so any plausible output type will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(out_shape, out_dtype)
5433
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the DEPTHWISE_CONV2D output tensor to ``ser`` and return it.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M). ``accum_dtype``
    is not used here. For ERROR_IF tests the output shape or dtype is
    corrupted according to ``error_name``.
    """

    def _conv_out_size(in_size, pad_before, pad_after, kernel, dilation, stride):
        # Standard dilated-convolution output size
        return (
            in_size - 1 + pad_before + pad_after - (kernel - 1) * dilation
        ) // stride + 1

    height = _conv_out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    width = _conv_out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit instead
        if which in [1, 3]:
            height += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            width += rng.choice(steps) * strides[1]

    # Output channels = input channels (C) * channel multiplier (M)
    out_shape = [ifm.shape[0], height, width, filter.shape[2] * filter.shape[3]]

    if error_name == ErrorIf.WrongInputType:
        # The input is already invalid, so any plausible output type will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005484
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the 2D pooling output tensor (NHWC) to ``ser`` and return it.

    The output keeps the input dtype and channel count. Height and width
    follow the usual pooling formula. For ERROR_IF tests the shape or dtype
    is corrupted according to ``error_name``.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # The test is invalid anyway; use 1x1 rather than computing nonsense
        out_h, out_w = 1, 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # Grow by whole strides so the non-integer-output error is not hit instead
        if which in [1, 3]:
            out_h += rng.choice(steps) * stride[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * stride[1]

    out_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, ifm.dtype)

    every_dtype = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    mismatched = list(set(every_dtype) - {ifm.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005524
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the FULLY_CONNECTED output tensor to ``ser`` and return it.

    ``input`` is [N, IC], ``filter`` is [OC, IC] and the output is [N, OC].
    The output dtype is ``accum_dtype`` as given. It is validated by the
    argument generator (and deliberately invalidated there for ERROR_IF).
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005537
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the MATMUL output tensor to ``ser`` and return it.

    ``a`` is [N, H, C], ``b`` is [N, C, W] and the output is [N, H, W].
    Normally the output dtype is ``accum_dtype``, which is validated in the
    argument generator. For ERROR_IF tests the dtype is replaced according
    to ``error_name``.
    """
    out_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        # Each tuple lists every dtype that is NOT a valid output for the
        # given input dtype. Tuple order matters for rng.choice.
        if a.dtype == DType.INT8:
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype == DType.INT16:
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        out_dtype = rng.choice(a=bad_dtypes)
    elif error_name == ErrorIf.WrongInputType:
        # The input is already invalid, so any plausible output type will do
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005603
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the CONCAT output tensor to ``ser`` and return it.

    The output shape is the first input's shape with the ``axis`` dimension
    summed over all inputs, and the dtype is the first input's dtype. For
    ERROR_IF tests the shape or dtype is corrupted per ``error_name``.
    """
    first = inputs[0]
    others = inputs[1:]

    out_shape = first.shape.copy()
    # Summing is impossible with mismatched ranks or an invalid axis; in
    # those cases just keep the first input's shape
    summable = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if summable:
        for other in others:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, first.dtype)

    candidate_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    mismatched = list(candidate_dtypes - {first.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005639
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the PAD output tensor to ``ser`` and return it.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``
    and the dtype matches ``a``. For ERROR_IF tests the shape or dtype is
    corrupted according to ``error_name``.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(out_shape):
        pad_before = padding[idx][0]
        pad_after = padding[idx][1]
        out_shape[idx] = pad_before + pad_after + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005670
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the DIM output tensor to ``ser`` and return it.

    The output is a single-element ([1]) tensor of dtype SHAPE. ``a`` and
    ``axis`` do not influence it. For the WrongOutputType ERROR_IF test a
    non-SHAPE dtype is chosen instead.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    non_shape_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput([1], rng.choice(list(set(non_shape_dtypes))))
5691
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the RESHAPE output tensor to ``ser`` and return it.

    The output takes the requested ``shape`` (copied) and the dtype of
    ``a``. For ERROR_IF tests the shape or dtype is corrupted according to
    ``error_name``.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension so the element counts no longer match
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005716
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the SLICE output tensor to ``ser`` and return it.

    The output shape is ``size`` (copied) and the dtype matches ``input``.
    ``start`` does not affect the output. For ERROR_IF tests the dtype or
    shape is corrupted according to ``error_name``.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))
    else:
        out_dtype = input.dtype

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Small dims may only grow so they never drop to zero or below
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005750
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the TILE output tensor to ``ser`` and return it.

    Each output dimension is ``a.shape[i] * multiples[i]`` and the dtype
    matches ``a``. For ERROR_IF tests the shape or dtype is corrupted
    according to ``error_name``.
    """
    assert len(multiples) == len(a.shape)
    out_shape = [dim * mult for dim, mult in zip(a.shape, multiples)]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005779
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the TRANSPOSE output tensor to ``ser`` and return it.

    Output dimension ``i`` is ``a.shape[perms[i]]`` and the dtype matches
    ``a``. For ERROR_IF tests the shape or dtype is corrupted according to
    ``error_name``.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    # Invalid permutations cannot be applied; keep the input shape instead
    if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
        out_shape = [a.shape[perm] for perm in perms]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(out_shape)):
            out_shape[idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidate_dtypes) - {a.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005812
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the GATHER output tensor to ``ser`` and return it.

    ``values`` is rank 3 and ``indices`` is rank 2 with a matching first
    dimension. These are checked unless testing WrongRank. The output is
    [values.shape[0], indices.shape[1], values.shape[2]] with the dtype of
    ``values``, unless the dtype is corrupted for WrongOutputType.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    out_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, values.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    mismatched = list(set(candidate_dtypes) - {values.dtype})
    return ser.addOutput(out_shape, rng.choice(mismatched))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005838
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the SCATTER output tensor to ``ser`` and return it.

    ``values_in`` is [N, K, C], ``indices`` is [N, W] and ``input`` is
    [N, W, C]. These are checked unless testing WrongRank. The output has
    the shape of ``values_in`` (the same list object is passed through) and
    its dtype, unless the dtype is corrupted for WrongOutputType.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(values_in.shape, values_in.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    mismatched = list(set(candidate_dtypes) - {values_in.dtype})
    return ser.addOutput(values_in.shape, rng.choice(mismatched))
Eric Kunzee5e26762020-10-13 16:11:07 -07005869
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the TABLE output tensor to ``ser`` and return it.

    The output has the input's shape. Its dtype is INT32 for an INT16 input
    and INT8 otherwise. For WrongOutputType a different dtype is picked at
    random.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    out_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        others = [dt for dt in candidate_dtypes if dt != out_dtype]
        out_dtype = rng.choice(others)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005889
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the RESIZE output tensor (NHWC) to ``serializer`` and return it.

    ``scale`` is (y_n, y_d, x_n, x_d), ``offset`` and ``border`` are (y, x).
    The output dtype is ``output_dtype`` as given. ``mode`` and
    ``input_dtype`` are not used here. For ERROR_IF tests the computed
    height/width are clamped to valid values and the shape may then be
    corrupted according to ``error_name``.
    """
    y_num = scale[0]
    y_den = scale[1]
    x_num = scale[2]
    x_den = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Keep the size computation well-defined with non-positive scales
        y_num, y_den, x_num, x_den = (max(v, 1) for v in (y_num, y_den, x_num, x_den))

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Altered scale/offset/border may produce an unusable size; clamp it
        out_h = max(out_h, 1)
        out_w = max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h = min(out_h, limit)
            out_w = min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        which = rng.choice([1, 2, 3])
        # Step by the scale denominators so the non-integer error is not hit
        if which in [1, 3]:
            if out_h + y_den < gtu.MAX_RESIZE_DIMENSION:
                out_h += y_den
            else:
                out_h -= y_den
                assert out_h > 0  # Should have been caught in agResize
        if which in [2, 3]:
            if out_w + x_den < gtu.MAX_RESIZE_DIMENSION:
                out_w += x_den
            else:
                out_w -= x_den
                assert out_w > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last dim repeats the batch size rather than the
        # channel count, presumably because input.shape[3] may not exist
        # when the input rank is wrong - confirm
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005968
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add an output with ``val``'s shape and dtype ``out_dtype``; return it.

    ``rng`` and ``error_name`` are accepted for interface uniformity but
    are not used.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005972
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the TRANSPOSE_CONV2D output tensor to ``ser`` and return it.

    The output uses ``output_shape`` as supplied. Its dtype comes from the
    input dtype via ``OutputShaper._get_conv_output_type``. ``accum_dtype``
    is not used here. For ERROR_IF tests the shape or dtype is corrupted
    according to ``error_name``.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # NOTE(review): this edits the caller's output_shape list in place,
        # so the caller sees the altered height/width - confirm intended
        if which in [1, 3]:
            output_shape[1] += rng.choice(steps)
        if which in [2, 3]:
            output_shape[2] += rng.choice(steps)

    if error_name == ErrorIf.WrongInputType:
        # The input is already invalid, so any plausible output type will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005998
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors of an FFT2D operator.

    ``ifm1`` and ``ifm2`` must share a dtype. Their shapes must be equal
    unless ``error_name`` is ``ErrorIf.FFTInputShapeMismatch``, and the
    shape must be rank 3 unless ``error_name`` is ``ErrorIf.WrongRank``.

    Both outputs use the same shape and dtype, copied from ``ifm1`` and then
    optionally perturbed:
      * WrongOutputType: dtype drawn from usable dtypes other than FP32.
      * BatchMismatch: dim 0 grown by 1..9.
      * FFTOutputShapeMismatch: dim 1 or 2 grown by 1..9.

    Returns:
        A list holding the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape
    if error_name != ErrorIf.WrongRank:
        assert len(ifm1.shape) == 3

    # Copy so the perturbations below never touch ifm1's own shape.
    out_shape = ifm1.shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        out_shape[target_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
6029
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors of an RFFT2D operator.

    The output shape copies every dimension of ``value`` except the last,
    which becomes ``W // 2 + 1`` (W being the input's last dimension). The
    input must be rank 3 unless ``error_name`` is ``ErrorIf.WrongRank``.

    Both outputs use the same shape and the input dtype, optionally
    perturbed:
      * WrongOutputType: dtype drawn from usable dtypes other than FP32.
      * BatchMismatch: dim 0 grown by 1..9.
      * FFTOutputShapeMismatch: dim 1 or 2 grown by 1..9.

    Returns:
        A list holding the two tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1])
    out_shape.append(in_shape[-1] // 2 + 1)
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        target_dim = rng.choice([1, 2])
        out_shape[target_dim] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00006054
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for the shape-addition operator.

    ``a`` and ``b`` must share a dtype, and must have the same rank unless
    ``error_name`` is ``ErrorIf.RankMismatch``. The output shape copies
    ``a.shape``:
      * DimensionMismatch: one randomly chosen dimension is incremented.
      * WrongOutputType: the dtype is drawn from
        ``gtu.get_wrong_output_type(a.dtype)`` instead of ``DType.SHAPE``.

    Returns:
        Whatever ``ser.addOutput`` returns for the new tensor.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Copy so the fuzzing below never touches a's own shape.
    shape = list(a.shape)

    # Drawn on every call, not only for DimensionMismatch: moving it into
    # the branch would change the RNG draw sequence (and so the generated
    # tests). NOTE(review): presumably fails for a rank-0 ``a`` (empty
    # range) — confirm whether that input can occur.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(shape, out_dtype)