blob: 7702753e3809e2af2ddaa774b24a5c525fe66480 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010023from generator.tosa_random_gen import TosaHashRandomGenerator
24from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010025from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000026from tosa.DType import DType
27from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010028
# Header written at the top of the generated C++ meta data files. Note the
# year is evaluated once, when this module is imported.
TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
// SPDX-License-Identifier: Apache-2.0
// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
"""

# Give the root logger a default handler (does nothing if one is already set)
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
36
Matthew Haddonb724efc2021-08-25 16:40:29 +010037
Eric Kunzee5e26762020-10-13 16:11:07 -070038class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    # This currently matches the 8K level defined in the specification.
    TOSA_TENSOR_MAX_RANK = 6
    # Further maximums named after the specification's 8K level (scale,
    # kernel size and stride). NOTE(review): presumably used to bound the
    # generated attribute values - confirm against their users.
    TOSA_8K_LEVEL_MAX_SCALE = 64
    TOSA_8K_LEVEL_MAX_KERNEL = 8192
    TOSA_8K_LEVEL_MAX_STRIDE = 8192

    # Main compliance dot product statistical test range
    # (number of test sets and a minimum value - TODO confirm the exact
    # meaning of TOSA_MI_DOT_PRODUCT_MIN with its users)
    TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
    TOSA_MI_DOT_PRODUCT_MIN = 1000
49
Eric Kunzee5e26762020-10-13 16:11:07 -070050 def __init__(self, args):
51 self.args = args
52 self.basePath = args.output_dir
53 self.random_seed = args.random_seed
54 self.ser = None
Eric Kunzee5e26762020-10-13 16:11:07 -070055 self.createDynamicOpLists()
56 self.initOpListDefaults()
57 self.quantGen = TosaQuantGen()
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010058 self.global_rng = None
Eric Kunzee5e26762020-10-13 16:11:07 -070059 # Force makeShape to do a specific starting shape
60 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010061 # JSON schema validation
62 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010063 # Data generator library is sometimes needed for compliance set up
64 # even if we are generating the data later (lazy_data_generation)
65 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070066
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010067 # Work out floating point range
68 def convertFPRange(rangeFP, maxFP):
69 # Converts program arguments of max/-max to FP max
70 vals = []
71 for v in rangeFP:
72 if v == "max":
73 v = maxFP
74 elif v == "-max":
75 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000076 elif v < 0:
77 # Trim to minimum data type value
78 v = max(v, -maxFP)
79 elif v > 0:
80 # Trim to maximum data type value
81 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010082 vals.append(v)
83 return tuple(sorted(vals))
84
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010085 self.random_dtype_range = {
86 DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
87 }
Won Jeon2c34b462024-02-06 18:37:00 +000088 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010089 self.random_dtype_range[dtype] = convertFPRange(
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010090 args.tensor_fp_value_range,
91 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
92 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010093 self.resetGlobalRNG()
94
95 def resetGlobalRNG(self):
96 self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010097
Eric Kunzee5e26762020-10-13 16:11:07 -070098 def createSerializer(self, opName, testPath):
99 self.testPath = os.path.join(opName, testPath)
100
101 fullPath = os.path.join(self.basePath, self.testPath)
102 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100103 # Embed const data in the flatbuffer
104 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +0100105 if self.args.lazy_data_gen:
106 # Lazy data generation - so make constants files
107 constMode = ts.ConstMode.INPUTS
108 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100109 constMode = ts.ConstMode.EMBED_DUMP
110 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700111
112 def getSerializer(self):
113 return self.ser
114
Jeremy Johnson1271c442023-09-05 11:39:26 +0100115 def serialize(self, testName, metaData=None):
116 path = Path(self.basePath) / self.testPath
117
118 # Write out TOSA flatbuffer binary
119 path_fb = path / f"{testName}.tosa"
120 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700121 fd.write(self.ser.serialize())
122
Jeremy Johnson1271c442023-09-05 11:39:26 +0100123 # Get JSON descriptor from serializer
124 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
125
126 if metaData:
127 # Add extra meta data to desc.json
128 desc["meta"] = metaData
129
130 # Validate desc.json before we output it
131 self.descSchemaValidator.validate_config(desc)
132
133 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100134 if "data_gen" in metaData:
135 if self.args.lazy_data_gen:
136 # Output datagen meta data as CPP data
137 path_md = path / f"{testName}_meta_data_gen.cpp"
138 with path_md.open("w") as fd:
139 fd.write(TOSA_AUTOGENERATED_HEADER)
140 fd.write("// Test meta data for data generation setup\n\n")
141 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
142 json.dump(metaData["data_gen"], fd)
143 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100144 if "compliance" in metaData:
145 # Output datagen meta data as CPP data
146 path_md = path / f"{testName}_meta_compliance.cpp"
147 with path_md.open("w") as fd:
148 fd.write(TOSA_AUTOGENERATED_HEADER)
149 fd.write("// Test meta data for compliance validation\n\n")
150 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
151 json.dump(metaData["compliance"], fd)
152 fd.write(')";\n\n')
153
154 # Write desc.json
155 path_desc = path / "desc.json"
156 with path_desc.open("w") as fd:
157 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700158
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100159 def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700160 placeholders = []
161
Kevin Cheng989cb052021-04-28 16:29:44 -0700162 assert len(shape_list) == len(dtype_list)
163
Jeremy Johnson1271c442023-09-05 11:39:26 +0100164 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700165 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100166 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100167 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700168 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700169
170 return placeholders
171
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100172 def buildConstTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700173 consts = []
174
Kevin Cheng989cb052021-04-28 16:29:44 -0700175 assert len(shape_list) == len(dtype_list)
176
Jeremy Johnson1271c442023-09-05 11:39:26 +0100177 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700178 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100179 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100180 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700181 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700182
183 return consts
184
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100185 def makeShape(self, rng, rank):
Eric Kunzee5e26762020-10-13 16:11:07 -0700186 if self.targetted_shape:
187 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800188 return np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100189 rng.integers(
Kevin Cheng550ccc52021-03-03 11:21:43 -0800190 low=self.args.tensor_shape_range[0],
191 high=self.args.tensor_shape_range[1],
192 size=rank,
193 )
194 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700195
196 def setTargetShape(self, shape):
197 self.targetted_shape = shape
198
Eric Kunzee5e26762020-10-13 16:11:07 -0700199 def shapeStr(self, shape):
200
201 sStr = []
202 # Convert to strings
203 for i in shape:
204 sStr.append(str(i))
205
Kevin Cheng550ccc52021-03-03 11:21:43 -0800206 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700207
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100208 def typeStr(self, dtype):
209 if isinstance(dtype, list) or isinstance(dtype, tuple):
210 assert len(dtype) >= 2
211 strs = [self.typeStr(t) for t in dtype]
212 # Limit types to the first 2 as the 3rd is the accumulator
213 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100215 if dtype in gtu.DTYPE_ATTRIBUTES:
216 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700217 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100218 raise Exception(
219 "Unknown dtype, cannot convert to string: {}".format(dtype)
220 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700221
Luke Hutton57287132023-02-06 14:54:18 +0000222 def constrictBatchSize(self, shape):
223 # Limit the batch size unless an explicit target shape set
224 if self.args.max_batch_size and not self.args.target_shapes:
225 shape[0] = min(shape[0], self.args.max_batch_size)
226 return shape
227
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100228 def makeDimension(self, rng):
229 return rng.randInt(
James Ward30124a82023-02-02 14:56:33 +0000230 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
231 )
232
    def tensorComplianceMetaData(
        self, op, inputType, argsDict, outputTensor, errorName
    ):
        """Create the compliance meta data dictionary for an op's output tensor.

        Args:
            op: op definition dict; reads ``op["op"]`` and the optional
                ``op["compliance"]`` settings.
            inputType: input dtype; only checked for the dot product style ops
                listed in UNSUPPORTED_NON_FP32_INPUT_OPS.
            argsDict: test arguments; always reads "dg_type" and, depending on
                the chosen mode, "s", "ks"/"ksb", "max_abs_value" or "n".
            outputTensor: result tensor whose dtype is described.
            errorName: ERROR_IF test name, falsy for a normal test.

        Returns:
            dict with "mode" (a gtu.ComplianceMode name), "data_type" and any
            mode specific "*_info" entry; or None for error tests and for
            dtypes not supported by compliance checking.
        """
        # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
        UNSUPPORTED_NON_FP32_INPUT_OPS = (
            Op.MATMUL,
            Op.CONV2D,
            Op.FULLY_CONNECTED,
            Op.DEPTHWISE_CONV2D,
            Op.TRANSPOSE_CONV2D,
            Op.CONV3D,
        )
        if (
            errorName
            or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
            or (
                not gtu.dtypeIsSupportedByCompliance(inputType)
                and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
            )
        ):
            # No compliance for error tests or unsupported types currently
            return None

        # Create compliance meta data for expected output tensor
        compliance_tens = {
            "mode": None,
            # Data type is needed for all FP runs, as refmodel precise mode produces FP64
            "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
        }
        # Select the compliance mode. The first matching branch wins, so the
        # data generator type takes precedence over per-op compliance settings.
        if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
            mode = gtu.ComplianceMode.DOT_PRODUCT
            compliance_tens["dot_product_info"] = {
                "s": argsDict["s"],
                # "ksb" is used in preference to "ks" when present
                "ks": int(argsDict["ksb"])
                if "ksb" in argsDict
                else int(argsDict["ks"]),
            }
        elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
            mode = gtu.ComplianceMode.FP_SPECIAL
        elif "compliance" in op and "ulp" in op["compliance"]:
            mode = gtu.ComplianceMode.ULP
            compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
        elif "compliance" in op and "relative" in op["compliance"]:
            mode = gtu.ComplianceMode.RELATIVE
            compliance_tens["relative_info"] = {
                "max": argsDict["max_abs_value"],
                "scale": op["compliance"]["relative"],
            }
        elif op["op"] == Op.REDUCE_PRODUCT:
            mode = gtu.ComplianceMode.REDUCE_PRODUCT
            compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
        elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
            mode = gtu.ComplianceMode.ABS_ERROR
            # Optional lower bound supplied by the op definition
            if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
                compliance_tens["abs_error_info"] = {
                    "lower_bound": op["compliance"]["abs_error_lower_bound"]
                }
        elif op["op"] in (Op.SIN, Op.COS):
            mode = gtu.ComplianceMode.ABS_ERROR
            # Optional normal divisor supplied by the op definition
            if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
                compliance_tens["abs_error_info"] = {
                    "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
                }
        else:
            mode = gtu.ComplianceMode.EXACT
        # Store the mode by its enum name
        compliance_tens["mode"] = gtu.ComplianceMode(mode).name

        return compliance_tens
301
302 # Build Op functions
303 # Create the output tensor (calling OutputShaper as needed)
304 # Do final tweaks to attributes (if necessary for errorIf)
305 # Add Op into graph
306 # Return resulting tensor information or BuildInfo
307
308 class BuildInfo:
309 """Enhanced build information containing result tensor and associated compliance dict."""
310
311 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000312 if isinstance(resultTensor, list):
313 assert complianceDict is None or isinstance(complianceDict, list)
314 self.resultTensorList = resultTensor
315 self.complianceDictList = complianceDict
316 else:
317 self.resultTensorList = [resultTensor]
318 if complianceDict is None:
319 self.complianceDictList = None
320 else:
321 self.complianceDictList = [complianceDict]
322
323 def getComplianceInfo(self):
324 if self.complianceDictList is None:
325 return None
326 else:
327 tens_dict = {}
328 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
329 if comp is not None:
330 tens_dict[tens.name] = comp
331
332 if tens_dict:
333 # Have some compliance data, so return the info
334 compliance = {
335 "version": "0.1",
336 "tensors": tens_dict,
337 }
338 else:
339 compliance = None
340 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700341
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000342 def build_unary(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100343 self,
344 rng,
345 op,
346 inputs,
347 args_dict,
348 validator_fcns=None,
349 error_name=None,
350 qinfo=None,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000351 ):
352 assert len(inputs) == 1
353 a = inputs[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100354 result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100355
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000356 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100357
358 # Ensure new output type has correct qinfo
359 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000360 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000361 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100362 TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype),
363 TosaQuantGen.getZeroPoint(
364 rng, self.args.zeropoint, result_tensor.dtype
365 ),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000366 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100367
368 # Invalidate Input/Output list for error if checks.
369 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000370 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100371 pCount, cCount = op["operands"]
372 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000373 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100374 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000375 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100376
Les Bell729b0352021-11-24 10:28:21 +0000377 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100378 self.ser,
379 validator_fcns,
380 error_name,
381 op=op,
382 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000383 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000384 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000385 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100386 input_list=input_list,
387 output_list=output_list,
388 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000389 ):
390 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100391
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000392 attr = None
393 if op["op"] == Op.NEGATE:
394 attr = ts.TosaSerializerAttribute()
395 attr.NegateAttribute(qinfo[0], qinfo[1])
396
397 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000398
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000399 compliance = self.tensorComplianceMetaData(
400 op, a.dtype, args_dict, result_tensor, error_name
401 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000402 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700403
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000404 def build_binary_broadcast(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100405 self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000406 ):
407 assert len(inputs) == 2
408 a, b = inputs
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100409 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100410
411 # Invalidate Input/Output list for error if checks.
412 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000413 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100414 pCount, cCount = op["operands"]
415 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000416 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100417 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000418 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100419
Les Bell729b0352021-11-24 10:28:21 +0000420 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100421 self.ser,
422 validator_fcns,
423 error_name,
424 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000425 input1=a,
426 input2=b,
427 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000428 output_dtype=result_tensor.dtype,
429 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100430 input_list=input_list,
431 output_list=output_list,
432 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000433 ):
434 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100435
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000436 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000437
Jeremy Johnson9a758382023-11-07 16:27:35 +0000438 compliance = self.tensorComplianceMetaData(
439 op, a.dtype, args_dict, result_tensor, error_name
440 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000441
442 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700443
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000444 def build_arithmetic_right_shift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100445 self,
446 rng,
447 op,
448 inputs,
449 args_dict,
450 validator_fcns=None,
451 error_name=None,
452 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000453 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000454 assert len(inputs) == 2
455 a, b = inputs
456 round = args_dict["round"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100457 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100458
459 # Invalidate Input/Output list for error if checks.
460 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000461 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100462 pCount, cCount = op["operands"]
463 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000464 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100465 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000466 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100467
Les Bell729b0352021-11-24 10:28:21 +0000468 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100469 self.ser,
470 validator_fcns,
471 error_name,
472 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000473 input1=a,
474 input2=b,
475 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000476 output_dtype=result_tensor.dtype,
477 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100478 input_list=input_list,
479 output_list=output_list,
480 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000481 ):
482 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800483
484 attr = ts.TosaSerializerAttribute()
485 attr.ArithmeticRightShiftAttribute(round)
486
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000487 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000488
489 compliance = self.tensorComplianceMetaData(
490 op, a.dtype, args_dict, result_tensor, error_name
491 )
492
493 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800494
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100495 def build_mul(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100496 self,
497 rng,
498 op,
499 inputs,
500 args_dict,
501 validator_fcns=None,
502 error_name=None,
503 qinfo=None,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100504 ):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000505 # Note that mul is binary operator but it has a shift value tensor
506 assert len(inputs) == 3
507 a, b, s = inputs
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100508
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100509 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700510
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100511 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100512 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100513 result_tensor.setDtype(DType.INT32)
514
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100515 if error_name == ErrorIf.WrongOutputType:
516 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100517 outputDType = rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100518 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100519
520 # Invalidate Input/Output list for error if checks.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000521 input_list = [a.name, b.name, s.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100522 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100523 pCount, cCount = op["operands"]
524 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000525 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100526 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000527 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100528
Les Bell729b0352021-11-24 10:28:21 +0000529 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100530 self.ser,
531 validator_fcns,
532 error_name,
533 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000534 input1=a,
535 input2=b,
536 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100537 output_dtype=result_tensor.dtype,
538 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100539 input_list=input_list,
540 output_list=output_list,
541 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000542 ):
543 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700544
Jeremy Johnson0a042992024-02-28 13:20:05 +0000545 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100546
547 compliance = self.tensorComplianceMetaData(
548 op, a.dtype, args_dict, result_tensor, error_name
549 )
550
551 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700552
Jeremy Johnson587cc842024-02-08 11:45:44 +0000553 def build_table(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100554 self,
555 rng,
556 op,
557 inputs,
558 args_dict,
559 validator_fcns=None,
560 error_name=None,
561 qinfo=None,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000562 ):
563 assert len(inputs) == 1
564 a = inputs[0]
565 table = args_dict["table"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100566 result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700567
Kevin Chengfe392ce2021-10-18 21:51:55 +0000568 attr = ts.TosaSerializerAttribute()
569 attr.TableAttribute(table)
570
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100571 # Invalidate Input/Output list for error if checks.
572 input_list = [a.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000573 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100574 pCount, cCount = op["operands"]
575 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000576 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100577 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000578 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100579
Les Bell729b0352021-11-24 10:28:21 +0000580 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100581 self.ser,
582 validator_fcns,
583 error_name,
584 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000585 input_shape=a.shape,
586 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000587 output_dtype=result_tensor.dtype,
588 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100589 input_list=input_list,
590 output_list=output_list,
591 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000592 ):
593 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100594
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000595 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700596
Jeremy Johnson587cc842024-02-08 11:45:44 +0000597 compliance = self.tensorComplianceMetaData(
598 op, a.dtype, args_dict, result_tensor, error_name
599 )
600
601 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700602
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000603 def build_select(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100604 self,
605 rng,
606 op,
607 inputs,
608 args_dict,
609 validator_fcns=None,
610 error_name=None,
611 qinfo=None,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000612 ):
613 assert len(inputs) == 3
614 cond, a, b = inputs
615
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100616 result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100617
618 # Invalidate Input/Output list for error if checks.
619 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000620 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100621 pCount, cCount = op["operands"]
622 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000623 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100624 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000625 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100626
Les Bell729b0352021-11-24 10:28:21 +0000627 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100628 self.ser,
629 validator_fcns,
630 error_name,
631 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000632 input1=cond,
633 input2=a,
634 input3=b,
635 input_shape=a.shape,
636 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000637 output_dtype=result_tensor.dtype,
638 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100639 input_list=input_list,
640 output_list=output_list,
641 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000642 ):
643 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100644
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000645 self.ser.addOperator(
646 op["op"],
647 input_list,
648 output_list,
649 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000650 compliance = self.tensorComplianceMetaData(
651 op, a.dtype, args_dict, result_tensor, error_name
652 )
653
654 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700655
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a binary comparison operator and add it to the serializer.

    The two operands are taken from ``inputs``. The output tensor is created
    by ``OutputShaper.binaryComparisonOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the output tensor and its compliance
        metadata, or None when ``TosaErrorValidator.evValidateErrorIfs``
        returns a falsy value. ``qinfo`` is accepted for a uniform builder
        signature but is not used here.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(self.ser, rng, lhs, rhs, error_name)

    p_count, c_count = op["operands"]
    # For ERROR_IF tests the tensor name lists may be altered here.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [out_tensor.name]
    )

    error_if_kwargs = {
        "op": op,
        "input1": lhs,
        "input2": rhs,
        "input_shape": lhs.shape,
        "input_dtype": lhs.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": p_count + c_count,
    }
    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **error_if_kwargs
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700708
def build_argmax(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a single-input operator that takes an axis attribute (ARGMAX).

    Args:
        rng: random generator passed to the output shaper and ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments; ``args_dict["axis"]`` selects the axis.
        validator_fcns: ERROR_IF validator functions (optional, like the
            other builders).
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: unused; kept for a uniform builder signature.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.argmaxOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700752
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator with a PoolAttribute.

    Args:
        rng: random generator for the output shaper, zero points and the
            ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: reads "stride", "pad", "kernel" and the optional "acc_type"
            (``DType.UNKNOWN`` when absent).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [input_zp, output_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    ifm = inputs[0]  # named ifm to avoid shadowing the builtin ``input``
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100830
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D operator from (input, weight, bias) tensors.

    Args:
        rng: random generator for the output shaper, zero points and the
            ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence of exactly three tensors: input, weight, bias.
        args_dict: reads "acc_type", "stride", "pad" (4 values) and "dilation".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [input_zp, weight_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # "weight" avoids shadowing builtin filter()
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without quantization info, indexing qinfo below would raise TypeError;
    # use the same zero-point default as build_pool2d.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700915
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D operator from (input, weight, bias) tensors.

    Args:
        rng: random generator for the output shaper, zero points and the
            ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence of exactly three tensors: input, weight, bias.
        args_dict: reads "acc_type", "stride", "pad" (6 values) and "dilation".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [input_zp, weight_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # "weight" avoids shadowing builtin filter()
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without quantization info, indexing qinfo below would raise TypeError;
    # use the same zero-point default as build_pool2d.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001000
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D operator from (input, weight, bias) tensors.

    Args:
        rng: random generator for the output shaper, zero points and the
            ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence of exactly three tensors: input, weight, bias.
        args_dict: reads "acc_type", "stride", "pad" (4 values, used as the
            out_pad attribute) and "out_shape".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [input_zp, weight_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # "weight" avoids shadowing builtin filter()
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without quantization info, indexing qinfo below would raise TypeError;
    # use the same zero-point default as build_pool2d.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001078
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D operator from (input, weight, bias) tensors.

    Args:
        rng: random generator for the output shaper, zero points and the
            ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence of exactly three tensors: input, weight, bias.
        args_dict: reads "acc_type", "stride", "pad" and "dilation".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [input_zp, weight_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # "weight" avoids shadowing builtin filter()
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
    ):
        return None

    # Without quantization info, indexing qinfo below would raise TypeError;
    # use the same zero-point default as build_pool2d.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001162
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED operator from (input, weight, bias) tensors.

    Args:
        rng: random generator for the output shaper and ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence of exactly three tensors: input, weight, bias.
        args_dict: reads "acc_type".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [input_zp, weight_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs  # "weight" avoids shadowing builtin filter()
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Without quantization info, indexing qinfo below would raise TypeError;
    # use the same zero-point default as build_pool2d.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001219
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MATMUL operator from two input tensors.

    Args:
        rng: random generator for the output shaper and ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence of exactly two tensors (a, b).
        args_dict: reads "acc_type".
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: [a_zp, b_zp]; defaults to [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Without quantization info, indexing qinfo below would raise TypeError;
    # use the same zero-point default as build_pool2d.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001276
def build_reduce(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a REDUCE_* operator (single input, axis attribute).

    Args:
        rng: random generator for the output shaper and ERROR_IF helper.
        op: operator definition; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: sequence holding exactly one input tensor.
        args_dict: test arguments; ``args_dict["axis"]`` selects the axis.
            For a positive REDUCE_PRODUCT test this method also writes
            ``args_dict["n"]`` (the size of the reduced axis).
        validator_fcns: ERROR_IF validator functions (optional, like the
            other builders).
        error_name: ERROR_IF test name, or None for a positive test.
        qinfo: unused; kept for a uniform builder signature.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and compliance
        metadata, or None if ``evValidateErrorIfs`` returns a falsy value.
    """
    assert len(inputs) == 1
    a = inputs[0]
    axis = args_dict["axis"]
    result_tensor = OutputShaper.reduceOp(self.ser, rng, a, axis, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=a.shape,
        output_shape=result_tensor.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    if error_name is None and op["op"] == Op.REDUCE_PRODUCT:
        # Number of products - needed for compliance
        args_dict["n"] = a.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001325
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CLAMP operator test.

    Two random values of the input dtype are drawn from ``rng`` and ordered
    into ``min_val``/``max_val``.  For the ``ErrorIf.MaxSmallerMin`` negative
    test the values are redrawn until they differ and are then deliberately
    swapped.  The bounds are stored in a ClampAttribute as little-endian
    packed bytes.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    def draw_pair():
        # Two draws from the shared rng, first element drawn first.
        return [rng.randNumberDType(src.dtype), rng.randNumberDType(src.dtype)]

    bounds = draw_pair()
    if error_name != ErrorIf.MaxSmallerMin:
        max_val, min_val = max(bounds), min(bounds)
    else:
        # The two values must differ, otherwise swapping them below would
        # not produce max < min.
        while bounds[0] == bounds[1]:
            bounds = draw_pair()
        max_val, min_val = min(bounds), max(bounds)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not test_is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        fmt = "<f"
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        bound_values = (min_val, max_val)
    elif src.dtype in (DType.INT8, DType.INT16):
        fmt = "<i"
        bound_values = (min_val, max_val)
    else:
        # Placeholder bounds so that unsupported input types (ERROR_IF
        # tests) do not fail while packing.
        fmt = "<i"
        bound_values = (0, 0)
    min_val_as_bytes, max_val_as_bytes = (
        struct.pack(fmt, value) for value in bound_values
    )
    attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001405
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a single-input operator that is serialized without an attribute.

    The output tensor is shaped by ``OutputShaper.unaryOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    (src,) = inputs

    result_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001453
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONCAT or CONCAT_SHAPE operator test.

    CONCAT reads its axis from ``args_dict["axis"]`` and serializes it in an
    AxisAttribute.  CONCAT_SHAPE always uses axis 0 and is serialized without
    an attribute.

    Args:
        rng: Random generator, forwarded to the output shaper and to the
            ERROR_IF input/output list helper.
        op: Operator description (reads ``"op"`` and ``"operands"``).
        inputs: Tensors to concatenate.  ``inputs[0]`` supplies the shape and
            dtype reported to the validators and to the compliance metadata.
        args_dict: Test arguments (reads ``"axis"`` for CONCAT).
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or ``None``.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # isinstance() instead of an exact type() comparison (E721).
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        # CONCAT_SHAPE carries no attribute.
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001519
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator test from an input tensor and a padding tensor.

    ``args_dict["pad"]`` shapes the output tensor.  ``args_dict`` also carries
    an integer and a floating-point pad constant; the one matching the input
    dtype is stored in the PadAttribute as little-endian bytes.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    src, pad_input = inputs
    padding = args_dict["pad"]
    pad_const_int = args_dict["pad_const_int"]
    pad_const_float = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, rng, src, padding, error_name)

    # Float inputs encode the pad value as a 4-byte float ("<f").
    # Every other dtype encodes it as a 4-byte signed int ("<i").
    if not gtu.dtypeIsFloat(src.dtype):
        pad_const_val_as_bytes = struct.pack("<i", pad_const_int)
    else:
        pad_const_val_as_bytes = struct.pack("<f", pad_const_float)

    attr = ts.TosaSerializerAttribute()
    attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, pad_input.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001583
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator test for the axis in ``args_dict["axis"]``.

    The output tensor comes from ``OutputShaper.dimOp`` and the axis is
    serialized in an AxisAttribute.  No compliance metadata is produced: the
    returned ``BuildInfo`` carries ``None`` in that slot.

    Returns:
        ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF validators
        reject the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, rng, src, axis, error_name)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not test_is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001630
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESHAPE operator test.

    ``inputs`` holds the data tensor and a shape tensor.  Both are wired as
    operator inputs.  The output tensor is shaped from
    ``args_dict["new_shape"]`` via ``OutputShaper.reshapeOp``.  No attribute
    is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    src, shape_tensor = inputs
    shape_attr = args_dict["new_shape"]
    result_tensor = OutputShaper.reshapeOp(self.ser, rng, src, shape_attr, error_name)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, shape_tensor.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001679
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a REVERSE operator test for the axis in ``args_dict["axis"]``.

    The output tensor is shaped by ``OutputShaper.unaryOp`` and the axis is
    serialized in an AxisAttribute.  No compliance metadata is produced: the
    returned ``BuildInfo`` carries ``None`` in that slot.

    Returns:
        ``TosaTestGen.BuildInfo`` or ``None`` when the ERROR_IF validators
        reject the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not test_is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001726
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE operator test for the permutation ``args_dict["perms"]``.

    The output tensor comes from ``OutputShaper.transposeOp`` and the
    permutation is serialized in a TransposeAttribute.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 1
    (src,) = inputs
    perms = args_dict["perms"]

    result_tensor = OutputShaper.transposeOp(self.ser, rng, src, perms, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(perms)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        perms=perms,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001780
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SLICE operator test.

    ``inputs`` holds the data tensor plus start and size tensors, all wired as
    operator inputs.  The constant ``args_dict["start"]``/``["size"]`` values
    shape the output (``OutputShaper.sliceOp``) and are passed to the
    validators.  No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 3
    src, start_var, size_var = inputs
    start_const = args_dict["start"]
    size_const = args_dict["size"]

    result_tensor = OutputShaper.sliceOp(
        self.ser, rng, src, start_const, size_const, error_name
    )

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [src.name, start_var.name, size_var.name],
        [result_tensor.name],
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        start=start_const,
        size=size_const,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001835
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TILE operator test.

    ``inputs`` holds the data tensor and a multiples tensor, both wired as
    operator inputs.  ``args_dict["multiples"]`` shapes the output via
    ``OutputShaper.tileOp``.  No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs
    multiples_attr = args_dict["multiples"]
    result_tensor = OutputShaper.tileOp(
        self.ser, rng, src, multiples_attr, error_name
    )

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, multiples_tensor.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
        input1=src,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001887
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a GATHER operator test from a values tensor and an indices tensor.

    The output tensor comes from ``OutputShaper.gatherOp``.  The values tensor
    supplies the shape/dtype reported to the validators and to the
    compliance metadata.  No attribute is serialized.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensor and compliance
        metadata, or ``None`` when the ERROR_IF validators reject the test.
    """
    assert len(inputs) == 2
    values, indices = inputs

    result_tensor = OutputShaper.gatherOp(self.ser, rng, values, indices, error_name)

    # Input/output name lists may be invalidated here for ERROR_IF tests.
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [values.name, indices.name], [result_tensor.name]
    )

    test_is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values.shape,
        output_shape=result_tensor.shape,
        input_dtype=values.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=p_count + c_count,
    )
    if not test_is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001937
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test.

    ``inputs`` holds exactly three tensors (values_in, indices, input), which
    are forwarded in that order to ``OutputShaper.scatterOp``.

    Returns a TosaTestGen.BuildInfo for the result tensor, or None when the
    ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 3
    values_in, indices, scatter_in = inputs
    out_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, scatter_in, error_name
    )

    placeholder_count, const_count = op["operands"]
    # The name lists may be deliberately corrupted when testing ERROR_IF cases.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [tensor.name for tensor in (values_in, indices, scatter_in)],
        [out_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, out_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001986
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` holds the input tensor followed by the scale, offset and border
    tensors; all four names become operator inputs. ``args_dict`` supplies
    "mode", "scale", "offset", "border" and "output_dtype", which drive the
    output shape calculation and the ERROR_IF validation.

    Returns a TosaTestGen.BuildInfo for the result tensor, or None when the
    ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 4
    in_tensor, scale_tensor, offset_tensor, border_tensor = inputs

    mode, scale, offset, border, output_dtype = (
        args_dict[key]
        for key in ("mode", "scale", "offset", "border", "output_dtype")
    )

    out_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        in_tensor,
        mode,
        scale,
        offset,
        border,
        in_tensor.dtype,
        output_dtype,
        error_name,
    )

    placeholder_count, const_count = op["operands"]
    # The name lists may be deliberately corrupted when testing ERROR_IF cases.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [
            tensor.name
            for tensor in (in_tensor, scale_tensor, offset_tensor, border_tensor)
        ],
        [out_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=in_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=in_tensor.shape,
        output_shape=out_tensor.shape,
        offset=offset,
        border=border,
        input_list=in_names,
        output_list=out_names,
        result_tensors=[out_tensor],
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    resize_attr = ts.TosaSerializerAttribute()
    # scale/offset/border are supplied through the input tensors listed above,
    # so empty lists are written into the attribute and only mode is set.
    resize_attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], in_names, out_names, resize_attr)

    return TosaTestGen.BuildInfo(
        out_tensor,
        self.tensorComplianceMetaData(
            op, in_tensor.dtype, args_dict, out_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002066
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST test by marking the single input tensor as a graph output.

    No operator is added. Returns a TosaTestGen.BuildInfo for that same
    tensor, together with its compliance metadata.
    """
    assert len(inputs) == 1
    (const_tensor,) = inputs
    self.ser.addOutputTensor(const_tensor)

    return TosaTestGen.BuildInfo(
        const_tensor,
        self.tensorComplianceMetaData(
            op, const_tensor.dtype, args_dict, const_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002086
2087 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST test converting the single input to ``args_dict["out_type"]``.

    Returns a TosaTestGen.BuildInfo for the converted tensor, or None when the
    ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 1
    (src,) = inputs
    target_dtype = args_dict["out_type"]

    converted = OutputShaper.typeConversionOp(
        self.ser, rng, src, target_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]
    # The name lists may be deliberately corrupted when testing ERROR_IF cases.
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [converted.name]
    )

    passes_checks = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=converted.shape,
        input_dtype=src.dtype,
        output_dtype=converted.dtype,
        result_tensors=[converted],
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not passes_checks:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance_info = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, converted, error_name
    )
    return TosaTestGen.BuildInfo(converted, compliance_info)
Eric Kunzee5e26762020-10-13 16:11:07 -07002138
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``inputs`` holds the value tensor followed by the multiplier and shift
    tensors. Input/output zero points are drawn from ``rng`` according to the
    input/output dtypes, or forced non-zero for the zero-point ERROR_IF tests.
    When ``args_dict["scale"]`` (scale32) is set and no error is being tested,
    the stored input values are clamped so that ``value - input_zp`` lies in
    ``[-(1 << (shift - 1)), (1 << (shift - 1)) - 1]`` per channel, and the
    placeholder file is rewritten if anything changed.

    The ``qinfo`` argument is not used; the zero points chosen here are
    passed to the ERROR_IF validators instead.

    Returns a TosaTestGen.BuildInfo for the result tensor, or None when the
    ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # One shift value per channel (last dimension) when per_channel is set
    nc = val.shape[-1] if per_channel else 1

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = rng.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = rng.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = rng.randInt(-128, 128)
        if input_zp == 0:
            # Force a non-zero zero point so the error condition holds
            input_zp = input_zp + rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = rng.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = rng.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = rng.randInt(-128, 128)
        if output_zp == 0:
            # Force a non-zero zero point so the error condition holds
            output_zp = output_zp + rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    min_shift_value_arr = np.zeros(shape=[nc], dtype=np.int64)
    max_shift_value_arr = np.zeros(shape=[nc], dtype=np.int64)

    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the stored dtype: compare the
        # actual element range (not a boolean reduction) with the dtype limits.
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        )

        # Cast back to the stored input dtype (range checked above)
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002305
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Add a constant condition tensor holding ``[cond]`` and return it.

    Normally the tensor is BOOL with a rank-0 shape. The
    CondIfCondNotMatchingBool ERROR_IF test gets a wrong dtype from
    ``gtu.get_wrong_output_type``. The CondIfCondShapeNotSizeOne test gets
    shape [2] or [1, 2], chosen by ``rng``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2322
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks each output an INT32 constant.

    Only the shape of the first input tensor is used; neither input is wired
    into the graph. The operator's sole input is the condition tensor built
    from ``args_dict["condition"]``. For the CondIfOutputList{Then,Else}
    GraphMismatch ERROR_IF tests, the matching block outputs a constant with a
    perturbed shape instead.

    Returns a TosaTestGen.BuildInfo for the result tensor, or None when the
    ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 2
    shape_source, _ = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = shape_source.shape
    dtype = DType.INT32

    mismatch_then = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    mismatch_else = error_name == ErrorIf.CondIfOutputListElseGraphMismatch
    if mismatch_then or mismatch_else:
        # Perturb every dimension; small dimensions only grow so they stay positive
        incorrect_shape = [
            dim
            + (rng.choice([-3, -2, 2, 3]) if dim > 3 else rng.choice([1, 2, 4]))
            for dim in out_shape
        ]
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    cond_attr = ts.TosaSerializerAttribute()
    cond_attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], cond_attr)

    # Each block just outputs a constant; the mismatching block (if any) gets
    # the perturbed shape.
    for block_name, block_arr, use_incorrect in (
        (then_block, then_arr, mismatch_then),
        (else_block, else_arr, mismatch_else),
    ):
        self.ser.addBasicBlock(block_name)
        if use_incorrect:
            block_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_tens = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(op, dtype, args_dict, result_tensor, error_name),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002408
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose then/else blocks apply a binary op to a and b.

    FP32/BF16/FP16/INT32 inputs use ADD in the then-block and SUB in the
    else-block; INT8/INT16 use LOGICAL_RIGHT_SHIFT / LOGICAL_LEFT_SHIFT.
    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    one block is given an input or output whose shape differs from ``a``.

    Returns a TosaTestGen.BuildInfo for the result tensor (with compliance
    data for the block selected by ``args_dict["condition"]``), or None when
    the ERROR_IF validators reject the configuration.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i in range(len(incorrect_shape)):
            # Perturb every dimension but keep it positive: subtracting from a
            # dimension <= 3 could yield a zero or negative size, so small
            # dimensions only grow (same rule as build_cond_if_const).
            incorrect_shape[i] += (
                rng.choice([-3, -2, 2, 3])
                if incorrect_shape[i] > 3
                else rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2513
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP operator test together with its COND and BODY blocks.

    The loop carries three tensors: a scalar INT32 counter initialised to
    ``args_dict["iterations"]``, the input tensor ``a``, and an accumulator
    placeholder with ``a``'s shape/dtype initialised to zeros. The COND block
    computes ``iter > 0``; the BODY block computes ``a + acc`` and
    ``iter - 1`` and passes ``a`` through unchanged.

    For ERROR_IF tests (``error_name`` set) selected tensor shapes, the
    condition dtype or the condition shape are perturbed using ``rng``.

    Args:
        rng: random generator; in this builder it is only consulted to pick
            perturbations for ERROR_IF variants.
        op: operator entry from TOSA_OP_LIST; ``op["op"]`` is serialized.
        inputs: list holding exactly one tensor, ``a``.
        args_dict: must contain ``"iterations"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF being generated, or None for a valid test.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the accumulator output tensor, or None when
        ERROR_IF validation returns a falsy result.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    # Scalar INT32 loop counter holding the requested iteration count.
    # NOTE: the local name `iter` shadows the Python builtin in this method.
    iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor: a placeholder with a's shape/dtype whose initial
    # data is an int32 numpy array of zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        # Give the accumulator output a shape that differs from its input.
        # NOTE(review): adding -3/-2 can make a dimension zero or negative.
        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3])
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    # Only acc_out is added as an output tensor of the top-level block
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        # Shape-perturbed copies of iter/acc used by the sub-graph mismatch
        # errors below. The scalar counter has no dimensions to perturb, so
        # one extra dimension is appended instead (which may be negative).
        incorrect_iter = deepcopy(iter)
        for i in range(len(incorrect_iter.shape)):
            incorrect_iter.shape[i] += rng.choice([-3, -2, 2, 3])
        if len(incorrect_iter.shape) == 0:
            incorrect_iter.shape.append(rng.choice([-3, -2, 2, 3]))

        incorrect_acc = deepcopy(acc)
        for i in range(len(incorrect_acc.shape)):
            incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3])

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        # Deliberately non-BOOL condition output
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        # Deliberately multi-element condition output
        choice = rng.choice([1, 2])
        if choice == 1:
            cond_shape = [3]
        else:
            cond_shape = [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, a + acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        self.ser.addInputTensor(incorrect_iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(incorrect_acc)
    else:
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )

    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002651
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator test from two input tensors.

    Output tensors come from OutputShaper.fft2dOp. For ERROR_IF tests the
    input/output name lists may be invalidated, and the test is rejected
    (None returned) when the ERROR_IF validators do not accept it.

    Args:
        rng: random generator for output shaping and list invalidation.
        op: operator entry from TOSA_OP_LIST.
        inputs: the two input tensors (presumably the real and imaginary
            parts - confirm against OutputShaper.fft2dOp).
        args_dict: must contain ``"inverse"``.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF being generated, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo with the output tensors and one compliance
        entry per output, or None.
    """
    assert len(inputs) == 2
    tens_a, tens_b = inputs
    is_inverse = args_dict["inverse"]

    out_tensors = OutputShaper.fft2dOp(self.ser, rng, tens_a, tens_b, error_name)

    in_names = [tens_a.name, tens_b.name]
    placeholder_count, const_count = op["operands"]

    out_names = [t.name for t in out_tensors]
    out_shapes = [t.shape for t in out_tensors]
    out_dtypes = [t.dtype for t in out_tensors]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=is_inverse,
        input1=tens_a,
        input2=tens_b,
        input_shape=tens_a.shape,
        input_dtype=tens_a.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=out_tensors,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not accepted:
        return None

    # TODO - Test local_bound; until then it is always serialized as False
    fft_attr = ts.TosaSerializerAttribute()
    fft_attr.FFTAttribute(is_inverse, False)

    self.ser.addOperator(op["op"], in_names, out_names, fft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, tens_a.dtype, args_dict, t, error_name)
        for t in out_tensors
    ]
    return TosaTestGen.BuildInfo(out_tensors, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002716
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator test from a single input tensor.

    Output tensors come from OutputShaper.rfft2dOp. For ERROR_IF tests the
    input/output name lists may be invalidated, and the test is rejected
    (None returned) when the ERROR_IF validators do not accept it.

    Args:
        rng: random generator for output shaping and list invalidation.
        op: operator entry from TOSA_OP_LIST.
        inputs: list holding exactly one input tensor.
        args_dict: builder arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF being generated, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo with the output tensors and one compliance
        entry per output, or None.
    """
    assert len(inputs) == 1
    in_tens = inputs[0]
    out_tensors = OutputShaper.rfft2dOp(self.ser, rng, in_tens, error_name)

    in_names = [in_tens.name]
    placeholder_count, const_count = op["operands"]

    out_names = [t.name for t in out_tensors]
    out_shapes = [t.shape for t in out_tensors]
    out_dtypes = [t.dtype for t in out_tensors]

    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=in_tens.shape,
        input_dtype=in_tens.dtype,
        output_shape=out_shapes,
        output_dtype=out_dtypes,
        result_tensors=out_tensors,
        input_list=in_names,
        output_list=out_names,
        num_operands=placeholder_count + const_count,
    )
    if not accepted:
        return None

    # TODO - Test local_bound; until then it is always serialized as False
    rfft_attr = ts.TosaSerializerAttribute()
    rfft_attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], in_names, out_names, rfft_attr)

    compliance = [
        self.tensorComplianceMetaData(op, in_tens.dtype, args_dict, t, error_name)
        for t in out_tensors
    ]
    return TosaTestGen.BuildInfo(out_tensors, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002774
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a test for a two-input shape operator.

    The result tensor is created by OutputShaper.addShapeOp. For ERROR_IF
    tests the input/output name lists may be invalidated, and the test is
    rejected (None returned) when the ERROR_IF validators do not accept it.

    Args:
        rng: random generator used for output shaping and for ERROR_IF
            input/output list invalidation.
        op: operator entry from TOSA_OP_LIST; ``op["op"]`` is serialized.
        inputs: the two input tensors.
        args_dict: builder arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: the ERROR_IF being generated, or None.
        qinfo: unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # The first argument is the random generator: pass `rng` (as the FFT
    # builders do). Passing `self` here was a bug, since the test generator
    # object is not a random number generator.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2827
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shapes, ranks and dtypes to generate tests for.

    Args:
        op: operator entry; its "rank" (min, max) and "types" are read.
        shapeFilter: requested shapes; falsy means [None] (no shape filter).
        rankFilter: requested ranks, or None.
        dtypeFilter: requested dtypes, or None. Compound (list) dtypes are
            kept when their first element is requested.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, required for "negative" tests. Its
            "param_reqs" override the rank/dtype/shape filters.

    Returns:
        A dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or
        None for a "negative" request without a validator, or for any other
        testType.
    """
    # Default rank range of 1-4 inclusive keeps test sizes reasonably small
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    # Rank filter: clamp the request to what the operator allows
    lowest_rank, highest_rank = op["rank"]
    op_ranks = range(lowest_rank, highest_rank + 1)
    if rankFilter is not None:
        ranks = [r for r in rankFilter if lowest_rank <= r <= highest_rank]
    elif shapes[0] is None:
        # No explicit filters: use the smaller of the op and default ranges
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = op_ranks

    # Dtype filter: operator dtypes that were also requested
    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        dtypes = [
            dt
            for dt in op_dtypes
            if dt in dtypeFilter or (isinstance(dt, list) and dt[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapes,
            "rankFilter": ranks,
            "dtypeFilter": dtypes,
        }
    if testType != "negative" or validator is None:
        return None

    # Negative tests: the validator may pin specific parameters
    reqs = validator(check=False, op=op)["param_reqs"]
    forced_rank = reqs["rank"]
    forced_dtype = reqs["dtype"]
    forced_shape = reqs["shape"]
    return {
        # Only the first two shapes are kept, to keep test numbers small
        "shapeFilter": shapes[:2] if forced_shape is None else forced_shape,
        "rankFilter": ranks if forced_rank is None else forced_rank,
        "dtypeFilter": dtypes if forced_dtype is None else forced_dtype,
    }
2908
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Enumerate the tests to generate for one operator.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: optional list of shapes to restrict generation to.
            None (the default) means no shape filter, which is equivalent
            to the previous ``[None]`` default.
        rankFilter: optional list of ranks to restrict generation to.
        dtypeFilter: optional list of dtypes to restrict generation to.
        testType: "positive" for valid tests or "negative" for ERROR_IF tests.

    Returns:
        A list of (opName, testStr, dtype, error_name, shapeList, args)
        tuples. Returns [] as soon as create_filter_lists yields None (for
        example a negative request for an op with no ERROR_IF validators).

    Raises:
        Exception: if opName is not in TOSA_OP_LIST.
    """
    # None sentinel instead of a shared mutable list default
    if shapeFilter is None:
        shapeFilter = [None]

    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, args)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
    else:
        error_if_validators = [None]

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
                    argList = []
                    if agen_fcn:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng

                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList

    return testList
3034
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Generate tensors and the operator graph for one test, then serialize it.

    Args:
        opName: key into TOSA_OP_LIST.
        testStr: test name; also used to seed the build RNG when
            ``args.stable_rng`` is set.
        dtype_or_dtypeList: either a per-operand list of dtypes, or a single
            dtype replicated for every operand (one per shape for the
            CONCAT / CONCAT_SHAPE ops).
        error_name: the ERROR_IF being generated, or None.
        shapeList: one shape per operand.
        argsDict: argument dictionary for the new tvg/build_fcn interface;
            must contain "dg_type".

    A falsy build result is logged as an invalid ERROR_IF test instead of
    being serialized.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # A fresh serializer for this test
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholders, consts = op["operands"]
    num_operands = placeholders + consts
    # CONCAT-style ops take one input per shape instead of a fixed count
    is_variadic = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        copies = len(shapeList) if is_variadic else num_operands
        dtypeList = [dtype_or_dtypeList] * copies

    if not is_variadic:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Per-test hashed RNG when stable_rng is requested, else the shared one
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    if qgen is None:
        qinfo = None
    else:
        qinfo = qgen(
            build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name
        )

    # Only the dictionary-based tvg/build_fcn interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvg_info = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    # Extra meta data for the desc.json
    tens_meta = {}
    if tvg_info.dataGenDict:
        tens_meta["data_gen"] = tvg_info.dataGenDict

    result = build_fcn(
        self,
        build_rng,
        op,
        tvg_info.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return

    # The test is valid: attach any compliance meta data and serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tens_meta["compliance"] = compliance
    self.serialize("test", tens_meta)
Matthew Haddon1c00b712021-10-01 15:51:03 +01003137
def createDynamicOpLists(self):
    """Expand the convolution *_TEMPLATE entries of TOSA_OP_LIST.

    For each kernel size a shallow copy of the matching template is added
    under a name such as ``conv2d_3x3`` or ``conv3d_1x1x2``, with "filter"
    set to the kernel, "template" set to False and "real_name" set to the
    base op name. Afterwards every entry whose "template" field is truthy
    is removed. Nothing happens if "conv2d_TEMPLATE" is already gone.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes: built around TOSA_8K_LEVEL_MAX_KERNEL for level8k runs
    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(base_name, kernel):
        # Register "<base_name>_<k0>x<k1>[x<k2>]" as a copy of its template
        variant = op_list[base_name + "_TEMPLATE"].copy()
        variant["filter"] = kernel
        variant["template"] = False
        variant["real_name"] = base_name
        suffix = "x".join(str(dim) for dim in kernel)
        op_list["{}_{}".format(base_name, suffix)] = variant

    for kernel in kernels_2d:
        for base_name in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(base_name, kernel)

    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Collect first, then delete: the dict must not change size while iterated
    stale = [name for name, entry in op_list.items() if entry.get("template", False)]
    for name in stale:
        del op_list[name]
3197
def initOpListDefaults(self):
    """Lint every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide a 2-item "operands" tuple, a 4-item "build_fcn"
    tuple, a "types" field and an "op" field. A missing "rank" is set to
    DEFAULT_RANK_RANGE.

    Raises:
        Exception: naming the first op found with a missing or malformed
            required field.
    """
    for op_name, op_entry in self.TOSA_OP_LIST.items():
        # Required fields
        try:
            _placeholders, _consts = op_entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op_name)
            )

        try:
            _build, _tgen, _tvgen, _arggen = op_entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    op_name
                )
            )

        try:
            op_entry["types"]
        except KeyError:
            raise Exception(
                "Op {} is missing a valid type list in TOSA_OP_LIST".format(op_name)
            )

        try:
            op_entry["op"]
        except KeyError:
            raise Exception(
                "Op {} is missing the Op field in TOSA_OP_LIST".format(op_name)
            )

        # Optional field: fall back to the default rank range
        if "rank" not in op_entry:
            op_entry["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003239
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
# 'rank': optional, restricts rank to tuple inclusive of (min, max);
#     if not specified, initOpListDefaults fills in DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen function,
#     TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# Optional extras used by the generator: 'qgen', 'error_if_validators',
#     'invalid_test_validators'
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types (excludes INT4)
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]
# Integer then floating-point types (excludes INT4)
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]
# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]
# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# 8/16-bit integer types plus the floating-point types
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
Eric Kunzee5e26762020-10-13 16:11:07 -07003284
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003285 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07003286 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003287 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003288 [DType.INT8, DType.INT8, DType.INT32],
3289 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01003290 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003291 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01003292 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003293 [DType.FP32, DType.FP32, DType.FP32],
Won Jeon2c34b462024-02-06 18:37:00 +00003294 [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
3295 [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
Kevin Cheng989cb052021-04-28 16:29:44 -07003296 ]
3297
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01003298 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003299
3300 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003301 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003302 "argmax": {
3303 "op": Op.ARGMAX,
3304 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003305 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003306 "build_fcn": (
3307 build_argmax,
3308 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003309 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003310 TosaArgGen.agAxis,
3311 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003312 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003313 "error_if_validators": (
3314 TosaErrorValidator.evAxisSmallerZero,
3315 TosaErrorValidator.evAxisLargerRank,
3316 TosaErrorValidator.evArgmaxOutputRankMismatch,
3317 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3318 TosaErrorValidator.evWrongRank,
3319 TosaErrorValidator.evWrongInputType,
3320 TosaErrorValidator.evWrongOutputType,
3321 TosaErrorValidator.evWrongInputList,
3322 TosaErrorValidator.evWrongOutputList,
3323 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003324 "data_gen": {
3325 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3326 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003327 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003328 "avg_pool2d": {
3329 "op": Op.AVG_POOL2D,
3330 "operands": (1, 0),
3331 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 "build_fcn": (
3333 build_pool2d,
3334 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003335 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003336 TosaArgGen.agPooling,
3337 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003338 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003339 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003340 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003341 "error_if_validators": (
3342 TosaErrorValidator.evKernelSmallerOne,
3343 TosaErrorValidator.evStrideSmallerOne,
3344 TosaErrorValidator.evPadSmallerZero,
3345 TosaErrorValidator.evWrongRank,
3346 TosaErrorValidator.evWrongInputType,
3347 TosaErrorValidator.evWrongOutputType,
3348 TosaErrorValidator.evWrongInputList,
3349 TosaErrorValidator.evWrongOutputList,
3350 TosaErrorValidator.evInputZeroPointNotZero,
3351 TosaErrorValidator.evOutputZeroPointNotZero,
3352 TosaErrorValidator.evPadLargerEqualKernel,
3353 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003354 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003355 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003356 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003357 "data_gen": {
3358 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3359 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003360 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003361 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003362 "conv2d_TEMPLATE": {
3363 "op": Op.CONV2D,
3364 "operands": (1, 2),
3365 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003366 "build_fcn": (
3367 build_conv2d,
3368 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003369 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003370 TosaArgGen.agConv,
3371 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003372 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003373 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003374 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3375 "error_if_validators": (
3376 TosaErrorValidator.evWrongInputType,
3377 TosaErrorValidator.evWrongOutputType,
3378 TosaErrorValidator.evWrongInputList,
3379 TosaErrorValidator.evWrongOutputList,
3380 TosaErrorValidator.evInputZeroPointNotZero,
3381 TosaErrorValidator.evWeightZeroPointNotZero,
3382 TosaErrorValidator.evPadSmallerZero,
3383 TosaErrorValidator.evStrideSmallerOne,
3384 TosaErrorValidator.evDilationSmallerOne,
3385 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003386 TosaErrorValidator.evConvOutputShapeMismatch,
3387 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003388 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003389 "data_gen": {
3390 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3391 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003392 "template": True,
3393 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003394 # Templated operator. Filled in by createDynamicOpLists
3395 "conv3d_TEMPLATE": {
3396 "op": Op.CONV3D,
3397 "operands": (1, 2),
3398 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 "build_fcn": (
3400 build_conv3d,
3401 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003402 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003403 TosaArgGen.agConv,
3404 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003405 "qgen": TosaQuantGen.qgConv,
3406 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003407 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3408 "error_if_validators": (
3409 TosaErrorValidator.evWrongInputType,
3410 TosaErrorValidator.evWrongOutputType,
3411 TosaErrorValidator.evWrongInputList,
3412 TosaErrorValidator.evWrongOutputList,
3413 TosaErrorValidator.evInputZeroPointNotZero,
3414 TosaErrorValidator.evWeightZeroPointNotZero,
3415 TosaErrorValidator.evPadSmallerZero,
3416 TosaErrorValidator.evStrideSmallerOne,
3417 TosaErrorValidator.evDilationSmallerOne,
3418 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003419 TosaErrorValidator.evConvOutputShapeMismatch,
3420 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003421 ),
evacha0147ab1762024-01-29 13:23:23 +00003422 "data_gen": {
3423 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3424 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003425 "template": True,
3426 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003427 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003428 "depthwise_conv2d_TEMPLATE": {
3429 "op": Op.DEPTHWISE_CONV2D,
3430 "operands": (1, 2),
3431 "filter": [1, 1],
3432 "rank": (4, 4),
3433 "build_fcn": (
3434 build_depthwise_conv2d,
3435 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003436 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003437 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003438 ),
3439 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003440 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003441 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3442 "error_if_validators": (
3443 TosaErrorValidator.evWrongInputType,
3444 TosaErrorValidator.evWrongOutputType,
3445 TosaErrorValidator.evWrongInputList,
3446 TosaErrorValidator.evWrongOutputList,
3447 TosaErrorValidator.evInputZeroPointNotZero,
3448 TosaErrorValidator.evWeightZeroPointNotZero,
3449 TosaErrorValidator.evPadSmallerZero,
3450 TosaErrorValidator.evStrideSmallerOne,
3451 TosaErrorValidator.evDilationSmallerOne,
3452 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003453 TosaErrorValidator.evConvOutputShapeMismatch,
3454 TosaErrorValidator.evConvOutputShapeNonInteger,
Les Bell0e027d42021-11-09 14:42:14 +00003455 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003456 "data_gen": {
3457 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3458 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003459 "template": True,
3460 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003461 "fully_connected": {
3462 "op": Op.FULLY_CONNECTED,
3463 "operands": (1, 2),
3464 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003465 "build_fcn": (
3466 build_fully_connected,
3467 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003468 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003469 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003470 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003471 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003472 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003473 "error_if_validators": (
3474 TosaErrorValidator.evInputZeroPointNotZero,
3475 TosaErrorValidator.evWeightZeroPointNotZero,
3476 TosaErrorValidator.evWrongRank,
3477 TosaErrorValidator.evWrongInputType,
3478 TosaErrorValidator.evWrongOutputType,
3479 TosaErrorValidator.evWrongInputList,
3480 TosaErrorValidator.evWrongOutputList,
3481 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003482 "data_gen": {
3483 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3484 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003485 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003486 "matmul": {
3487 "op": Op.MATMUL,
3488 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003489 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003490 "build_fcn": (
3491 build_matmul,
3492 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003493 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003494 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003495 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003497 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003498 "error_if_validators": (
3499 TosaErrorValidator.evInputZeroPointNotZero,
3500 TosaErrorValidator.evWrongRank,
3501 TosaErrorValidator.evWrongInputType,
3502 TosaErrorValidator.evWrongOutputType,
3503 TosaErrorValidator.evWrongInputList,
3504 TosaErrorValidator.evWrongOutputList,
3505 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003506 "data_gen": {
3507 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003508 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003509 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003510 "max_pool2d": {
3511 "op": Op.MAX_POOL2D,
3512 "operands": (1, 0),
3513 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003515 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003516 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003517 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003518 TosaArgGen.agPooling,
3519 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003520 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003521 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003522 "error_if_validators": (
3523 TosaErrorValidator.evKernelSmallerOne,
3524 TosaErrorValidator.evStrideSmallerOne,
3525 TosaErrorValidator.evPadSmallerZero,
3526 TosaErrorValidator.evWrongRank,
3527 TosaErrorValidator.evWrongInputType,
3528 TosaErrorValidator.evWrongOutputType,
3529 TosaErrorValidator.evWrongInputList,
3530 TosaErrorValidator.evWrongOutputList,
3531 TosaErrorValidator.evPadLargerEqualKernel,
3532 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003533 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003534 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003535 "data_gen": {
3536 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3537 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003538 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003539 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003540 "transpose_conv2d_TEMPLATE": {
3541 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003542 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003543 "rank": (4, 4),
3544 "build_fcn": (
3545 build_transpose_conv2d,
3546 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003547 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003548 TosaArgGen.agTransposeConv2D,
3549 ),
3550 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003551 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003552 "invalid_test_validators": (
3553 TosaInvalidValidator.ivHeightWidthInvalid,
3554 TosaInvalidValidator.ivNonPositiveOutputShape,
3555 ),
3556 "error_if_validators": (
3557 TosaErrorValidator.evWrongInputType,
3558 TosaErrorValidator.evWrongOutputType,
3559 TosaErrorValidator.evWrongInputList,
3560 TosaErrorValidator.evWrongOutputList,
3561 TosaErrorValidator.evInputZeroPointNotZero,
3562 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003563 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003564 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003565 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003566 TosaErrorValidator.evConvOutputShapeMismatch,
Les Bell0e027d42021-11-09 14:42:14 +00003567 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003568 "data_gen": {
3569 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3570 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003571 "template": True,
3572 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003573 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003574 "clamp": {
3575 "op": Op.CLAMP,
3576 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003577 "build_fcn": (
3578 build_clamp,
3579 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003580 TosaTensorValuesGen.tvgLazyGenDefault,
3581 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003582 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003583 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003584 "error_if_validators": (
3585 TosaErrorValidator.evMaxSmallerMin,
3586 TosaErrorValidator.evWrongInputType,
3587 TosaErrorValidator.evWrongOutputType,
3588 TosaErrorValidator.evWrongInputList,
3589 TosaErrorValidator.evWrongOutputList,
3590 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003591 "data_gen": {
3592 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3593 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003594 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003595 "sigmoid": {
3596 "op": Op.SIGMOID,
3597 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003598 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003599 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003600 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003601 TosaTensorValuesGen.tvgLazyGenDefault,
3602 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003603 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003604 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003605 "error_if_validators": (
3606 TosaErrorValidator.evWrongInputType,
3607 TosaErrorValidator.evWrongOutputType,
3608 TosaErrorValidator.evWrongInputList,
3609 TosaErrorValidator.evWrongOutputList,
3610 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003611 "data_gen": {
3612 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3613 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003614 },
3615 "tanh": {
3616 "op": Op.TANH,
3617 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003619 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003620 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003621 TosaTensorValuesGen.tvgLazyGenDefault,
3622 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003623 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003624 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003625 "error_if_validators": (
3626 TosaErrorValidator.evWrongInputType,
3627 TosaErrorValidator.evWrongOutputType,
3628 TosaErrorValidator.evWrongInputList,
3629 TosaErrorValidator.evWrongOutputList,
3630 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003631 "data_gen": {
3632 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3633 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003634 "compliance": {
3635 "abs_error_lower_bound": 0.5,
3636 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003637 },
Won Jeon78155c62023-06-10 00:20:04 +00003638 "erf": {
3639 "op": Op.ERF,
3640 "operands": (1, 0),
3641 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003642 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003643 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003644 TosaTensorValuesGen.tvgLazyGenDefault,
3645 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003646 ),
3647 "types": TYPE_FP,
3648 "error_if_validators": (
3649 TosaErrorValidator.evWrongInputType,
3650 TosaErrorValidator.evWrongOutputType,
3651 TosaErrorValidator.evWrongInputList,
3652 TosaErrorValidator.evWrongOutputList,
3653 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003654 "data_gen": {
3655 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3656 },
3657 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003658 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003659 # Elementwise Binary Operators
3660 "add": {
3661 "op": Op.ADD,
3662 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 "build_fcn": (
3664 build_binary_broadcast,
3665 TosaTensorGen.tgBroadcastFuzz,
3666 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003667 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003668 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003669 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003670 "error_if_validators": (
3671 TosaErrorValidator.evRankMismatch,
3672 TosaErrorValidator.evWrongInputType,
3673 TosaErrorValidator.evWrongOutputType,
3674 TosaErrorValidator.evWrongInputList,
3675 TosaErrorValidator.evWrongOutputList,
3676 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003677 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003678 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003679 "data_gen": {
3680 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3681 },
3682 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003683 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003684 "arithmetic_right_shift": {
3685 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3686 "operands": (2, 0),
3687 "build_fcn": (
3688 build_arithmetic_right_shift,
3689 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003690 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003691 TosaArgGen.agArithmeticRightShift,
3692 ),
3693 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003694 "error_if_validators": (
3695 TosaErrorValidator.evRankMismatch,
3696 TosaErrorValidator.evWrongInputType,
3697 TosaErrorValidator.evWrongOutputType,
3698 TosaErrorValidator.evWrongInputList,
3699 TosaErrorValidator.evWrongOutputList,
3700 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003701 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003702 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003703 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003704 "bitwise_and": {
3705 "op": Op.BITWISE_AND,
3706 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003707 "build_fcn": (
3708 build_binary_broadcast,
3709 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003710 TosaTensorValuesGen.tvgLazyGenDefault,
3711 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003712 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003713 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003714 "error_if_validators": (
3715 TosaErrorValidator.evRankMismatch,
3716 TosaErrorValidator.evWrongInputType,
3717 TosaErrorValidator.evWrongOutputType,
3718 TosaErrorValidator.evWrongInputList,
3719 TosaErrorValidator.evWrongOutputList,
3720 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003721 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003722 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "bitwise_or": {
3725 "op": Op.BITWISE_OR,
3726 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 "build_fcn": (
3728 build_binary_broadcast,
3729 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003730 TosaTensorValuesGen.tvgLazyGenDefault,
3731 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003732 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003733 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evRankMismatch,
3736 TosaErrorValidator.evWrongInputType,
3737 TosaErrorValidator.evWrongOutputType,
3738 TosaErrorValidator.evWrongInputList,
3739 TosaErrorValidator.evWrongOutputList,
3740 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003741 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003744 "bitwise_xor": {
3745 "op": Op.BITWISE_XOR,
3746 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 "build_fcn": (
3748 build_binary_broadcast,
3749 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003750 TosaTensorValuesGen.tvgLazyGenDefault,
3751 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003754 "error_if_validators": (
3755 TosaErrorValidator.evRankMismatch,
3756 TosaErrorValidator.evWrongInputType,
3757 TosaErrorValidator.evWrongOutputType,
3758 TosaErrorValidator.evWrongInputList,
3759 TosaErrorValidator.evWrongOutputList,
3760 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003761 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003763 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003764 "intdiv": {
3765 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003766 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 "build_fcn": (
3768 build_binary_broadcast,
3769 TosaTensorGen.tgBroadcastFuzz,
3770 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003771 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003773 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003774 "error_if_validators": (
3775 TosaErrorValidator.evRankMismatch,
3776 TosaErrorValidator.evWrongInputType,
3777 TosaErrorValidator.evWrongOutputType,
3778 TosaErrorValidator.evWrongInputList,
3779 TosaErrorValidator.evWrongOutputList,
3780 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003781 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003782 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003783 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003784 "logical_and": {
3785 "op": Op.LOGICAL_AND,
3786 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003787 "build_fcn": (
3788 build_binary_broadcast,
3789 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003790 TosaTensorValuesGen.tvgLazyGenDefault,
3791 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003792 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003793 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003794 "error_if_validators": (
3795 TosaErrorValidator.evRankMismatch,
3796 TosaErrorValidator.evWrongInputType,
3797 TosaErrorValidator.evWrongOutputType,
3798 TosaErrorValidator.evWrongInputList,
3799 TosaErrorValidator.evWrongOutputList,
3800 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003801 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003804 "logical_left_shift": {
3805 "op": Op.LOGICAL_LEFT_SHIFT,
3806 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 "build_fcn": (
3808 build_binary_broadcast,
3809 TosaTensorGen.tgBroadcastFuzz,
3810 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003811 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003813 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003814 "error_if_validators": (
3815 TosaErrorValidator.evRankMismatch,
3816 TosaErrorValidator.evWrongInputType,
3817 TosaErrorValidator.evWrongOutputType,
3818 TosaErrorValidator.evWrongInputList,
3819 TosaErrorValidator.evWrongOutputList,
3820 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003821 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003822 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003823 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003824 "logical_right_shift": {
3825 "op": Op.LOGICAL_RIGHT_SHIFT,
3826 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003827 "build_fcn": (
3828 build_binary_broadcast,
3829 TosaTensorGen.tgBroadcastFuzz,
3830 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003831 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 "error_if_validators": (
3835 TosaErrorValidator.evRankMismatch,
3836 TosaErrorValidator.evWrongInputType,
3837 TosaErrorValidator.evWrongOutputType,
3838 TosaErrorValidator.evWrongInputList,
3839 TosaErrorValidator.evWrongOutputList,
3840 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003841 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003842 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 "logical_or": {
3845 "op": Op.LOGICAL_OR,
3846 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 "build_fcn": (
3848 build_binary_broadcast,
3849 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003850 TosaTensorValuesGen.tvgLazyGenDefault,
3851 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003854 "error_if_validators": (
3855 TosaErrorValidator.evRankMismatch,
3856 TosaErrorValidator.evWrongInputType,
3857 TosaErrorValidator.evWrongOutputType,
3858 TosaErrorValidator.evWrongInputList,
3859 TosaErrorValidator.evWrongOutputList,
3860 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003861 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "logical_xor": {
3865 "op": Op.LOGICAL_XOR,
3866 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003867 "build_fcn": (
3868 build_binary_broadcast,
3869 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003870 TosaTensorValuesGen.tvgLazyGenDefault,
3871 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003872 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003874 "error_if_validators": (
3875 TosaErrorValidator.evRankMismatch,
3876 TosaErrorValidator.evWrongInputType,
3877 TosaErrorValidator.evWrongOutputType,
3878 TosaErrorValidator.evWrongInputList,
3879 TosaErrorValidator.evWrongOutputList,
3880 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003881 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 "maximum": {
3885 "op": Op.MAXIMUM,
3886 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003887 "build_fcn": (
3888 build_binary_broadcast,
3889 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003890 TosaTensorValuesGen.tvgLazyGenDefault,
3891 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003894 "error_if_validators": (
3895 TosaErrorValidator.evRankMismatch,
3896 TosaErrorValidator.evWrongInputType,
3897 TosaErrorValidator.evWrongOutputType,
3898 TosaErrorValidator.evWrongInputList,
3899 TosaErrorValidator.evWrongOutputList,
3900 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003901 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003902 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003903 "data_gen": {
3904 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3905 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003906 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003907 "minimum": {
3908 "op": Op.MINIMUM,
3909 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003910 "build_fcn": (
3911 build_binary_broadcast,
3912 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003913 TosaTensorValuesGen.tvgLazyGenDefault,
3914 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003915 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003916 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003917 "error_if_validators": (
3918 TosaErrorValidator.evRankMismatch,
3919 TosaErrorValidator.evWrongInputType,
3920 TosaErrorValidator.evWrongOutputType,
3921 TosaErrorValidator.evWrongInputList,
3922 TosaErrorValidator.evWrongOutputList,
3923 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003924 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003925 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003926 "data_gen": {
3927 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3928 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003929 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "mul": {
3931 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003932 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003933 "build_fcn": (
3934 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003935 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003936 TosaTensorValuesGen.tvgMul,
3937 TosaArgGen.agMul,
3938 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003939 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003940 "error_if_validators": (
3941 TosaErrorValidator.evWrongInputType,
3942 TosaErrorValidator.evWrongOutputType,
3943 TosaErrorValidator.evWrongInputList,
3944 TosaErrorValidator.evWrongOutputList,
3945 TosaErrorValidator.evRankMismatch,
3946 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003947 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003948 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003949 "data_gen": {
3950 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3951 },
3952 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003954 "pow": {
3955 "op": Op.POW,
3956 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003957 "build_fcn": (
3958 build_binary_broadcast,
3959 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003960 TosaTensorValuesGen.tvgPow,
3961 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003962 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003963 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003964 "error_if_validators": (
3965 TosaErrorValidator.evRankMismatch,
3966 TosaErrorValidator.evWrongInputType,
3967 TosaErrorValidator.evWrongOutputType,
3968 TosaErrorValidator.evWrongInputList,
3969 TosaErrorValidator.evWrongOutputList,
3970 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003971 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003972 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00003973 "data_gen": {
3974 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3975 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003977 "sub": {
3978 "op": Op.SUB,
3979 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003980 "build_fcn": (
3981 build_binary_broadcast,
3982 TosaTensorGen.tgBroadcastFuzz,
3983 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003984 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003985 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003986 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003987 "error_if_validators": (
3988 TosaErrorValidator.evRankMismatch,
3989 TosaErrorValidator.evWrongInputType,
3990 TosaErrorValidator.evWrongOutputType,
3991 TosaErrorValidator.evWrongInputList,
3992 TosaErrorValidator.evWrongOutputList,
3993 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003994 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003995 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003996 "data_gen": {
3997 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3998 },
3999 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004001 "table": {
4002 "op": Op.TABLE,
4003 # Use the automatic generation functions to create the input array
4004 # but create the table tensor in the build function, as it may be
4005 # a different type from the input
4006 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004007 "build_fcn": (
4008 build_table,
4009 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004010 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004011 TosaArgGen.agTable,
4012 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004013 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004014 "error_if_validators": (
4015 TosaErrorValidator.evWrongInputType,
4016 TosaErrorValidator.evWrongOutputType,
4017 TosaErrorValidator.evWrongInputList,
4018 TosaErrorValidator.evWrongOutputList,
4019 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004020 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004021 # Elementwise Unary operators
4022 "abs": {
4023 "op": Op.ABS,
4024 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004025 "build_fcn": (
4026 build_unary,
4027 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004028 TosaTensorValuesGen.tvgLazyGenDefault,
4029 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004030 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004031 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004032 "error_if_validators": (
4033 TosaErrorValidator.evWrongInputType,
4034 TosaErrorValidator.evWrongOutputType,
4035 TosaErrorValidator.evWrongInputList,
4036 TosaErrorValidator.evWrongOutputList,
4037 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004038 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004039 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004040 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004042 "bitwise_not": {
4043 "op": Op.BITWISE_NOT,
4044 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004045 "build_fcn": (
4046 build_unary,
4047 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004048 TosaTensorValuesGen.tvgLazyGenDefault,
4049 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004050 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004051 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004052 "error_if_validators": (
4053 TosaErrorValidator.evWrongInputType,
4054 TosaErrorValidator.evWrongOutputType,
4055 TosaErrorValidator.evWrongInputList,
4056 TosaErrorValidator.evWrongOutputList,
4057 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004059 "ceil": {
4060 "op": Op.CEIL,
4061 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 "build_fcn": (
4063 build_unary,
4064 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004065 TosaTensorValuesGen.tvgLazyGenDefault,
4066 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004068 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 "error_if_validators": (
4070 TosaErrorValidator.evWrongInputType,
4071 TosaErrorValidator.evWrongOutputType,
4072 TosaErrorValidator.evWrongInputList,
4073 TosaErrorValidator.evWrongOutputList,
4074 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004075 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004076 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004077 },
4078 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004079 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004080 "clz": {
4081 "op": Op.CLZ,
4082 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004083 "build_fcn": (
4084 build_unary,
4085 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004086 TosaTensorValuesGen.tvgLazyGenDefault,
4087 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004088 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004089 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004090 "error_if_validators": (
4091 TosaErrorValidator.evWrongInputType,
4092 TosaErrorValidator.evWrongOutputType,
4093 TosaErrorValidator.evWrongInputList,
4094 TosaErrorValidator.evWrongOutputList,
4095 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004097 "cos": {
4098 "op": Op.COS,
4099 "operands": (1, 0),
4100 "build_fcn": (
4101 build_unary,
4102 TosaTensorGen.tgBasic,
4103 TosaTensorValuesGen.tvgLazyGenDefault,
4104 TosaArgGen.agNone,
4105 ),
4106 "types": TYPE_FP,
4107 "error_if_validators": (
4108 TosaErrorValidator.evWrongInputType,
4109 TosaErrorValidator.evWrongOutputType,
4110 TosaErrorValidator.evWrongInputList,
4111 TosaErrorValidator.evWrongOutputList,
4112 ),
4113 "data_gen": {
4114 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4115 },
4116 "compliance": {"abs_error_normal_divisor": 2},
4117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004118 "exp": {
4119 "op": Op.EXP,
4120 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004121 "build_fcn": (
4122 build_unary,
4123 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004124 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004125 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004126 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004127 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004128 "error_if_validators": (
4129 TosaErrorValidator.evWrongInputType,
4130 TosaErrorValidator.evWrongOutputType,
4131 TosaErrorValidator.evWrongInputList,
4132 TosaErrorValidator.evWrongOutputList,
4133 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004134 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004135 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004136 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004137 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004138 "floor": {
4139 "op": Op.FLOOR,
4140 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004141 "build_fcn": (
4142 build_unary,
4143 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004144 TosaTensorValuesGen.tvgLazyGenDefault,
4145 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004146 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004147 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004148 "error_if_validators": (
4149 TosaErrorValidator.evWrongInputType,
4150 TosaErrorValidator.evWrongOutputType,
4151 TosaErrorValidator.evWrongInputList,
4152 TosaErrorValidator.evWrongOutputList,
4153 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004154 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004155 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004156 },
4157 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004159 "log": {
4160 "op": Op.LOG,
4161 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004162 "build_fcn": (
4163 build_unary,
4164 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004165 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004166 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004167 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004168 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004169 "error_if_validators": (
4170 TosaErrorValidator.evWrongInputType,
4171 TosaErrorValidator.evWrongOutputType,
4172 TosaErrorValidator.evWrongInputList,
4173 TosaErrorValidator.evWrongOutputList,
4174 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004175 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004176 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004177 },
4178 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004179 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004180 "logical_not": {
4181 "op": Op.LOGICAL_NOT,
4182 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004183 "build_fcn": (
4184 build_unary,
4185 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004186 TosaTensorValuesGen.tvgLazyGenDefault,
4187 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004188 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004189 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004190 "error_if_validators": (
4191 TosaErrorValidator.evWrongInputType,
4192 TosaErrorValidator.evWrongOutputType,
4193 TosaErrorValidator.evWrongInputList,
4194 TosaErrorValidator.evWrongOutputList,
4195 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004196 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004197 "negate": {
4198 "op": Op.NEGATE,
4199 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004200 "build_fcn": (
4201 build_unary,
4202 TosaTensorGen.tgBasic,
4203 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004204 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004205 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004206 "qgen": TosaQuantGen.qgUnary,
4207 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004208 "error_if_validators": (
4209 TosaErrorValidator.evInputZeroPointNotZero,
4210 TosaErrorValidator.evOutputZeroPointNotZero,
4211 TosaErrorValidator.evWrongInputType,
4212 TosaErrorValidator.evWrongOutputType,
4213 TosaErrorValidator.evWrongInputList,
4214 TosaErrorValidator.evWrongOutputList,
4215 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004216 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004217 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004218 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004219 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004220 "reciprocal": {
4221 "op": Op.RECIPROCAL,
4222 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004223 "build_fcn": (
4224 build_unary,
4225 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004226 TosaTensorValuesGen.tvgLazyGenDefault,
4227 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004228 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004229 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004230 "error_if_validators": (
4231 TosaErrorValidator.evWrongInputType,
4232 TosaErrorValidator.evWrongOutputType,
4233 TosaErrorValidator.evWrongInputList,
4234 TosaErrorValidator.evWrongOutputList,
4235 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004236 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004237 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004238 },
4239 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004241 "rsqrt": {
4242 "op": Op.RSQRT,
4243 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004244 "build_fcn": (
4245 build_unary,
4246 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004247 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004248 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004249 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004250 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004251 "error_if_validators": (
4252 TosaErrorValidator.evWrongInputType,
4253 TosaErrorValidator.evWrongOutputType,
4254 TosaErrorValidator.evWrongInputList,
4255 TosaErrorValidator.evWrongOutputList,
4256 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004257 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004258 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004259 },
4260 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004261 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004262 "sin": {
4263 "op": Op.SIN,
4264 "operands": (1, 0),
4265 "build_fcn": (
4266 build_unary,
4267 TosaTensorGen.tgBasic,
4268 TosaTensorValuesGen.tvgLazyGenDefault,
4269 TosaArgGen.agNone,
4270 ),
4271 "types": TYPE_FP,
4272 "error_if_validators": (
4273 TosaErrorValidator.evWrongInputType,
4274 TosaErrorValidator.evWrongOutputType,
4275 TosaErrorValidator.evWrongInputList,
4276 TosaErrorValidator.evWrongOutputList,
4277 ),
4278 "data_gen": {
4279 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4280 },
4281 "compliance": {"abs_error_normal_divisor": 2},
4282 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004283 # Elementwise Ternary operators
4284 "select": {
4285 "op": Op.SELECT,
4286 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004287 "build_fcn": (
4288 build_select,
4289 TosaTensorGen.tgBroadcastFuzz,
4290 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004291 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004292 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004293 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004294 "error_if_validators": (
4295 TosaErrorValidator.evRankMismatch,
4296 TosaErrorValidator.evWrongInputType,
4297 TosaErrorValidator.evWrongOutputType,
4298 TosaErrorValidator.evWrongInputList,
4299 TosaErrorValidator.evWrongOutputList,
4300 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004301 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004302 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004303 "data_gen": {
4304 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004306 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004307 # Comparison operators
4308 "equal": {
4309 "op": Op.EQUAL,
4310 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004311 "build_fcn": (
4312 build_comparison,
4313 TosaTensorGen.tgBroadcastFuzz,
4314 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004315 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004316 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004317 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004318 "error_if_validators": (
4319 TosaErrorValidator.evRankMismatch,
4320 TosaErrorValidator.evWrongInputType,
4321 TosaErrorValidator.evWrongOutputType,
4322 TosaErrorValidator.evWrongInputList,
4323 TosaErrorValidator.evWrongOutputList,
4324 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004325 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004326 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004327 "data_gen": {
4328 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4329 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004330 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004331 "greater_equal": {
4332 "op": Op.GREATER_EQUAL,
4333 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004334 "build_fcn": (
4335 build_comparison,
4336 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004337 TosaTensorValuesGen.tvgLazyGenDefault,
4338 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004339 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004340 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004341 "error_if_validators": (
4342 TosaErrorValidator.evRankMismatch,
4343 TosaErrorValidator.evWrongInputType,
4344 TosaErrorValidator.evWrongOutputType,
4345 TosaErrorValidator.evWrongInputList,
4346 TosaErrorValidator.evWrongOutputList,
4347 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004348 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004349 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004350 "data_gen": {
4351 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4352 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004354 "greater": {
4355 "op": Op.GREATER,
4356 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004357 "build_fcn": (
4358 build_comparison,
4359 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004360 TosaTensorValuesGen.tvgLazyGenDefault,
4361 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004362 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004363 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004364 "error_if_validators": (
4365 TosaErrorValidator.evRankMismatch,
4366 TosaErrorValidator.evWrongInputType,
4367 TosaErrorValidator.evWrongOutputType,
4368 TosaErrorValidator.evWrongInputList,
4369 TosaErrorValidator.evWrongOutputList,
4370 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004371 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004372 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004373 "data_gen": {
4374 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4375 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004376 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004377 # Reduction operators
4378 "reduce_all": {
4379 "op": Op.REDUCE_ALL,
4380 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004381 "build_fcn": (
4382 build_reduce,
4383 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004384 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004385 TosaArgGen.agAxis,
4386 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004387 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004388 "error_if_validators": (
4389 TosaErrorValidator.evAxisLargerRank,
4390 TosaErrorValidator.evAxisSmallerZero,
4391 TosaErrorValidator.evShapeOfAxisNotOne,
4392 TosaErrorValidator.evWrongInputType,
4393 TosaErrorValidator.evWrongOutputType,
4394 TosaErrorValidator.evWrongRank,
4395 TosaErrorValidator.evWrongInputList,
4396 TosaErrorValidator.evWrongOutputList,
4397 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004398 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004399 "reduce_any": {
4400 "op": Op.REDUCE_ANY,
4401 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004402 "build_fcn": (
4403 build_reduce,
4404 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004405 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004406 TosaArgGen.agAxis,
4407 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004408 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004409 "error_if_validators": (
4410 TosaErrorValidator.evAxisLargerRank,
4411 TosaErrorValidator.evAxisSmallerZero,
4412 TosaErrorValidator.evShapeOfAxisNotOne,
4413 TosaErrorValidator.evWrongInputType,
4414 TosaErrorValidator.evWrongOutputType,
4415 TosaErrorValidator.evWrongRank,
4416 TosaErrorValidator.evWrongInputList,
4417 TosaErrorValidator.evWrongOutputList,
4418 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004419 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004420 "reduce_max": {
4421 "op": Op.REDUCE_MAX,
4422 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004423 "build_fcn": (
4424 build_reduce,
4425 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004426 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004427 TosaArgGen.agAxis,
4428 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004429 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004430 "error_if_validators": (
4431 TosaErrorValidator.evAxisLargerRank,
4432 TosaErrorValidator.evAxisSmallerZero,
4433 TosaErrorValidator.evShapeOfAxisNotOne,
4434 TosaErrorValidator.evWrongInputType,
4435 TosaErrorValidator.evWrongOutputType,
4436 TosaErrorValidator.evWrongRank,
4437 TosaErrorValidator.evWrongInputList,
4438 TosaErrorValidator.evWrongOutputList,
4439 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004440 "data_gen": {
4441 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4442 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004443 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004444 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004445 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004446 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004447 "build_fcn": (
4448 build_reduce,
4449 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004450 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004451 TosaArgGen.agAxis,
4452 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004453 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004454 "error_if_validators": (
4455 TosaErrorValidator.evAxisLargerRank,
4456 TosaErrorValidator.evAxisSmallerZero,
4457 TosaErrorValidator.evShapeOfAxisNotOne,
4458 TosaErrorValidator.evWrongInputType,
4459 TosaErrorValidator.evWrongOutputType,
4460 TosaErrorValidator.evWrongRank,
4461 TosaErrorValidator.evWrongInputList,
4462 TosaErrorValidator.evWrongOutputList,
4463 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004464 "data_gen": {
4465 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4466 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004467 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004468 "reduce_product": {
4469 "op": Op.REDUCE_PRODUCT,
4470 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004471 "build_fcn": (
4472 build_reduce,
4473 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004474 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004475 TosaArgGen.agAxis,
4476 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004477 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004478 "error_if_validators": (
4479 TosaErrorValidator.evAxisLargerRank,
4480 TosaErrorValidator.evAxisSmallerZero,
4481 TosaErrorValidator.evShapeOfAxisNotOne,
4482 TosaErrorValidator.evWrongInputType,
4483 TosaErrorValidator.evWrongOutputType,
4484 TosaErrorValidator.evWrongRank,
4485 TosaErrorValidator.evWrongInputList,
4486 TosaErrorValidator.evWrongOutputList,
4487 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004488 "data_gen": {
4489 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4490 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004491 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004492 "reduce_sum": {
4493 "op": Op.REDUCE_SUM,
4494 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004495 "build_fcn": (
4496 build_reduce,
4497 TosaTensorGen.tgBasic,
4498 TosaTensorValuesGen.tvgReduceSum,
4499 TosaArgGen.agAxis,
4500 ),
James Ward24dbc422022-10-19 12:20:31 +01004501 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004502 "error_if_validators": (
4503 TosaErrorValidator.evAxisLargerRank,
4504 TosaErrorValidator.evAxisSmallerZero,
4505 TosaErrorValidator.evShapeOfAxisNotOne,
4506 TosaErrorValidator.evWrongInputType,
4507 TosaErrorValidator.evWrongOutputType,
4508 TosaErrorValidator.evWrongRank,
4509 TosaErrorValidator.evWrongInputList,
4510 TosaErrorValidator.evWrongOutputList,
4511 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004512 "data_gen": {
4513 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4514 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004515 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004516 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004517 "concat": {
4518 "op": Op.CONCAT,
4519 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004520 "build_fcn": (
4521 build_concat,
4522 TosaTensorGen.tgConcat,
4523 TosaTensorValuesGen.tvgConcat,
4524 TosaArgGen.agAxis,
4525 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004526 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004527 "error_if_validators": (
4528 TosaErrorValidator.evAxisLargerRank,
4529 TosaErrorValidator.evAxisSmallerZero,
4530 TosaErrorValidator.evConcatInputRankMismatch,
4531 TosaErrorValidator.evConcatShapeSumMismatch,
4532 TosaErrorValidator.evConcatInputDimMismatch,
4533 TosaErrorValidator.evWrongInputType,
4534 TosaErrorValidator.evWrongOutputType,
4535 TosaErrorValidator.evWrongOutputList,
4536 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004537 "data_gen": {
4538 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4539 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004540 },
4541 "pad": {
4542 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004543 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004544 "build_fcn": (
4545 build_pad,
4546 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004547 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004548 TosaArgGen.agPad,
4549 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004550 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004551 "error_if_validators": (
4552 TosaErrorValidator.evWrongInputType,
4553 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004554 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004555 TosaErrorValidator.evWrongOutputType,
4556 TosaErrorValidator.evWrongInputList,
4557 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004558 TosaErrorValidator.evRankMismatch,
4559 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004560 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004561 "data_gen": {
4562 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4563 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004564 },
Won Jeona21b2e82023-08-10 10:33:01 +00004565 "dim": {
4566 "op": Op.DIM,
4567 "operands": (1, 0),
4568 "build_fcn": (
4569 build_dim,
4570 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004571 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004572 TosaArgGen.agAxis,
4573 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004574 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004575 "error_if_validators": (
4576 TosaErrorValidator.evAxisLargerRank,
4577 TosaErrorValidator.evAxisSmallerZero,
4578 TosaErrorValidator.evWrongInputType,
4579 TosaErrorValidator.evWrongInputList,
4580 TosaErrorValidator.evWrongOutputList,
4581 TosaErrorValidator.evWrongRank,
4582 ),
4583 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004584 "reshape": {
4585 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004586 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004587 "build_fcn": (
4588 build_reshape,
4589 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004590 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004591 TosaArgGen.agReshape,
4592 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004593 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004594 "error_if_validators": (
4595 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4596 TosaErrorValidator.evWrongInputType,
4597 TosaErrorValidator.evWrongOutputType,
4598 TosaErrorValidator.evWrongInputList,
4599 TosaErrorValidator.evWrongOutputList,
4600 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004601 "data_gen": {
4602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4603 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 },
4605 "reverse": {
4606 "op": Op.REVERSE,
4607 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004608 "build_fcn": (
4609 build_reverse,
4610 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004611 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004612 TosaArgGen.agAxis,
4613 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004614 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004615 "error_if_validators": (
4616 TosaErrorValidator.evAxisSmallerZero,
4617 TosaErrorValidator.evAxisLargerRank,
4618 TosaErrorValidator.evWrongInputType,
4619 TosaErrorValidator.evWrongOutputType,
4620 TosaErrorValidator.evWrongInputList,
4621 TosaErrorValidator.evWrongOutputList,
4622 ),
evacha0198477222024-01-26 12:25:32 +00004623 "data_gen": {
4624 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4625 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004626 },
4627 "slice": {
4628 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004629 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004630 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004631 "build_fcn": (
4632 build_slice,
4633 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004634 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004635 TosaArgGen.agSlice,
4636 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004637 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004638 "error_if_validators": (
                # TODO Turn off these error categories for now as the reference
                # model cannot allocate memory space for an empty tensor. We can
                # probably report an accurate error message at the right place
                # during execution.
4643 # TosaErrorValidator.evStartSmallerZero,
4644 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004645 TosaErrorValidator.evStartSizeOutsideBounds,
4646 TosaErrorValidator.evSizeOutputShapeMismatch,
4647 TosaErrorValidator.evInputSizeStartLengthMismatch,
4648 TosaErrorValidator.evWrongRank,
4649 TosaErrorValidator.evWrongInputType,
4650 TosaErrorValidator.evWrongOutputType,
4651 TosaErrorValidator.evWrongInputList,
4652 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004653 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004654 ),
evacha017f7d4252024-01-24 12:08:09 +00004655 "data_gen": {
4656 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4657 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004658 },
4659 "tile": {
4660 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004661 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004662 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004663 "build_fcn": (
4664 build_tile,
4665 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004666 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004667 TosaArgGen.agTile,
4668 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004669 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004670 "error_if_validators": (
4671 TosaErrorValidator.evWrongInputType,
4672 TosaErrorValidator.evWrongOutputType,
4673 TosaErrorValidator.evWrongInputList,
4674 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004675 TosaErrorValidator.evRankMismatch,
4676 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004677 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004678 "data_gen": {
4679 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4680 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004681 },
4682 "transpose": {
4683 "op": Op.TRANSPOSE,
4684 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004685 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004686 "build_fcn": (
4687 build_transpose,
4688 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004689 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004690 TosaArgGen.agTranspose,
4691 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004692 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004693 "error_if_validators": (
4694 TosaErrorValidator.evIndexOutsideBounds,
4695 TosaErrorValidator.evIndexUsedTwice,
4696 TosaErrorValidator.evWrongInputType,
4697 TosaErrorValidator.evWrongOutputType,
4698 TosaErrorValidator.evWrongInputList,
4699 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004700 TosaErrorValidator.evWrongRank,
4701 TosaErrorValidator.evRankMismatch,
4702 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004703 ),
evacha0198477222024-01-26 12:25:32 +00004704 "data_gen": {
4705 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4706 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004707 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004708 # Data nodes
4709 "const": {
4710 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004711 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004712 "build_fcn": (
4713 build_const,
4714 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004715 TosaTensorValuesGen.tvgLazyGenDefault,
4716 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004717 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004718 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004719 "data_gen": {
4720 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4721 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004722 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004723 "identity": {
4724 "op": Op.IDENTITY,
4725 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004726 "build_fcn": (
4727 build_unary,
4728 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004729 TosaTensorValuesGen.tvgLazyGenDefault,
4730 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004731 ),
evacha011adff832024-03-06 17:33:44 +00004732 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004733 "data_gen": {
4734 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4735 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004736 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004737 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004738 "gather": {
4739 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004740 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004741 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004742 "build_fcn": (
4743 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004744 TosaTensorGen.tgGather,
4745 TosaTensorValuesGen.tvgGather,
4746 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004747 ),
James Ward24dbc422022-10-19 12:20:31 +01004748 "types": (
4749 DType.INT8,
4750 DType.INT16,
4751 DType.INT32,
4752 DType.FP16,
4753 DType.BF16,
4754 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004755 DType.FP8E4M3,
4756 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004757 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004758 "error_if_validators": (
4759 TosaErrorValidator.evWrongInputType,
4760 TosaErrorValidator.evWrongOutputType,
4761 TosaErrorValidator.evWrongInputList,
4762 TosaErrorValidator.evWrongOutputList,
4763 TosaErrorValidator.evWrongRank,
4764 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004765 "data_gen": {
4766 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4767 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004768 },
4769 "scatter": {
4770 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004771 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004772 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004773 "build_fcn": (
4774 build_scatter,
4775 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004776 TosaTensorValuesGen.tvgScatter,
4777 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004778 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004779 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004780 "error_if_validators": (
4781 TosaErrorValidator.evWrongInputType,
4782 TosaErrorValidator.evWrongOutputType,
4783 TosaErrorValidator.evWrongInputList,
4784 TosaErrorValidator.evWrongOutputList,
4785 TosaErrorValidator.evWrongRank,
4786 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004787 "data_gen": {
4788 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4789 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004790 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004791 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004792 "resize": {
4793 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004794 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004795 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004796 "build_fcn": (
4797 build_resize,
4798 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004799 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004800 TosaArgGen.agResize,
4801 ),
James Ward24dbc422022-10-19 12:20:31 +01004802 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004803 "invalid_test_validators": (
4804 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004805 ),
4806 "error_if_validators": (
4807 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004808 TosaErrorValidator.evScaleSmallerEqualZero,
4809 TosaErrorValidator.evScaleNLargerMax,
4810 TosaErrorValidator.evScaleDLargerMax,
4811 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004812 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004813 TosaErrorValidator.evBorderSmallerMin,
4814 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004815 TosaErrorValidator.evWrongInputType,
4816 TosaErrorValidator.evWrongOutputType,
4817 TosaErrorValidator.evWrongRank,
4818 TosaErrorValidator.evWrongInputList,
4819 TosaErrorValidator.evWrongOutputList,
4820 TosaErrorValidator.evBatchMismatch,
4821 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004822 TosaErrorValidator.evResizeOutputShapeMismatch,
4823 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004824 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004825 "data_gen": {
4826 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4827 },
4828 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004829 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004830 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004831 "cast": {
4832 "op": Op.CAST,
4833 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004834 "build_fcn": (
4835 build_cast,
4836 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004837 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004838 TosaArgGen.agCast,
4839 ),
James Ward8b390432022-08-12 20:48:56 +01004840 "types": (
4841 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004842 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004843 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004844 DType.INT8,
4845 DType.INT16,
4846 DType.INT32,
4847 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004848 DType.FP8E4M3,
4849 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004850 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004851 "error_if_validators": (
4852 TosaErrorValidator.evWrongInputType,
4853 TosaErrorValidator.evWrongOutputType,
4854 TosaErrorValidator.evWrongInputList,
4855 TosaErrorValidator.evWrongOutputList,
4856 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004857 "data_gen": {
4858 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4859 },
4860 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004861 },
4862 "rescale": {
4863 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004864 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004865 "build_fcn": (
4866 build_rescale,
4867 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004868 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004869 TosaArgGen.agRescale,
4870 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004871 "types": [
4872 DType.UINT8,
4873 DType.INT8,
4874 DType.INT16,
4875 DType.INT32,
4876 DType.INT48,
4877 DType.UINT16,
4878 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004879 "error_if_validators": (
4880 TosaErrorValidator.evInputZeroPointNotZero,
4881 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004882 TosaErrorValidator.evU16InputZeroPointNotValid,
4883 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004884 TosaErrorValidator.evScaleTrue,
4885 TosaErrorValidator.evScaleNotTrue,
4886 TosaErrorValidator.evWrongInputType,
4887 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004888 TosaErrorValidator.evWrongInputList,
4889 TosaErrorValidator.evWrongOutputList,
4890 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004891 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004892 # Custom
4893 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004894 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors
        # (no inputs to the basic blocks, one output) and another that either adds
        # or subtracts two tensors (two inputs to the basic blocks, one output).
Kevin Cheng550ccc52021-03-03 11:21:43 -08004898 "cond_if_const": {
4899 "op": Op.COND_IF,
4900 "operands": (0, 2),
4901 "build_fcn": (
4902 build_cond_if_const,
4903 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004904 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004905 TosaArgGen.agCondIf,
4906 ),
4907 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004908 "error_if_validators": (
4909 TosaErrorValidator.evOutputListThenGraphMismatch,
4910 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004911 TosaErrorValidator.evCondIfCondNotMatchingBool,
4912 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004913 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004914 },
4915 "cond_if_binary": {
4916 "op": Op.COND_IF,
4917 "operands": (2, 0),
4918 "build_fcn": (
4919 build_cond_if_binary,
4920 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004921 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004922 TosaArgGen.agCondIf,
4923 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004924 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004925 "error_if_validators": (
4926 TosaErrorValidator.evInputListThenGraphMismatch,
4927 TosaErrorValidator.evInputListElseGraphMismatch,
4928 TosaErrorValidator.evOutputListThenGraphMismatch,
4929 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004930 TosaErrorValidator.evCondIfCondNotMatchingBool,
4931 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004932 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004933 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004934 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004935 "while_loop": {
4936 "op": Op.WHILE_LOOP,
4937 "operands": (0, 1),
4938 "build_fcn": (
4939 build_while_loop,
4940 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004941 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004942 TosaArgGen.agWhileLoop,
4943 ),
4944 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004945 "error_if_validators": (
4946 TosaErrorValidator.evInputListOutputListMismatch,
4947 TosaErrorValidator.evInputListCondGraphMismatch,
4948 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4949 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4950 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004951 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004952 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004953 },
Luke Hutton57287132023-02-06 14:54:18 +00004954 "fft2d": {
4955 "op": Op.FFT2D,
4956 "operands": (2, 0),
4957 "rank": (3, 3),
4958 "build_fcn": (
4959 build_fft2d,
4960 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004961 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004962 TosaArgGen.agFFT2d,
4963 ),
4964 "types": [DType.FP32],
4965 "error_if_validators": (
4966 TosaErrorValidator.evWrongInputType,
4967 TosaErrorValidator.evWrongOutputType,
4968 TosaErrorValidator.evWrongInputList,
4969 TosaErrorValidator.evWrongOutputList,
4970 TosaErrorValidator.evWrongRank,
4971 TosaErrorValidator.evBatchMismatch,
4972 TosaErrorValidator.evKernelNotPowerOfTwo,
4973 TosaErrorValidator.evFFTInputShapeMismatch,
4974 TosaErrorValidator.evFFTOutputShapeMismatch,
4975 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004976 "data_gen": {
4977 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4978 },
Luke Hutton57287132023-02-06 14:54:18 +00004979 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004980 "rfft2d": {
4981 "op": Op.RFFT2D,
4982 "operands": (1, 0),
4983 "rank": (3, 3),
4984 "build_fcn": (
4985 build_rfft2d,
4986 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004987 TosaTensorValuesGen.tvgLazyGenDefault,
4988 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004989 ),
4990 "types": [DType.FP32],
4991 "error_if_validators": (
4992 TosaErrorValidator.evWrongInputType,
4993 TosaErrorValidator.evWrongOutputType,
4994 TosaErrorValidator.evWrongInputList,
4995 TosaErrorValidator.evWrongOutputList,
4996 TosaErrorValidator.evWrongRank,
4997 TosaErrorValidator.evBatchMismatch,
4998 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004999 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00005000 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005001 "data_gen": {
5002 "fp": (gtu.DataGenType.DOT_PRODUCT,),
5003 },
Luke Hutton261b7b62023-01-10 14:50:31 +00005004 },
Won Jeon74342e52024-01-09 00:34:40 +00005005 # Shape
5006 "add_shape": {
5007 "op": Op.ADD_SHAPE,
5008 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005009 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005010 "build_fcn": (
5011 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005012 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005013 TosaTensorValuesGen.tvgAddSub,
5014 TosaArgGen.agNone,
5015 ),
5016 "types": [DType.SHAPE],
5017 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5018 },
5019 "sub_shape": {
5020 "op": Op.SUB_SHAPE,
5021 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005022 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005023 "build_fcn": (
5024 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005025 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005026 TosaTensorValuesGen.tvgAddSub,
5027 TosaArgGen.agNone,
5028 ),
5029 "types": [DType.SHAPE],
5030 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5031 },
5032 "mul_shape": {
5033 "op": Op.MUL_SHAPE,
5034 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005035 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005036 "build_fcn": (
5037 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005038 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005039 TosaTensorValuesGen.tvgMul,
5040 TosaArgGen.agNone,
5041 ),
5042 "types": [DType.SHAPE],
5043 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5044 },
5045 "div_shape": {
5046 "op": Op.DIV_SHAPE,
5047 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005048 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005049 "build_fcn": (
5050 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005051 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005052 TosaTensorValuesGen.tvgIntDiv,
5053 TosaArgGen.agNone,
5054 ),
5055 "types": [DType.SHAPE],
5056 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5057 },
5058 "concat_shape": {
5059 "op": Op.CONCAT_SHAPE,
5060 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005061 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005062 "build_fcn": (
5063 build_concat,
5064 TosaTensorGen.tgConcat,
5065 TosaTensorValuesGen.tvgConcat,
5066 TosaArgGen.agNone,
5067 ),
5068 "types": [DType.SHAPE],
5069 "error_if_validators": (),
5070 },
5071 "const_shape": {
5072 "op": Op.CONST_SHAPE,
5073 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005074 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005075 "build_fcn": (
5076 build_const,
5077 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005078 TosaTensorValuesGen.tvgLazyGenDefault,
5079 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005080 ),
5081 "types": [DType.SHAPE],
5082 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005083 }
5084
Kevin Cheng550ccc52021-03-03 11:21:43 -08005085
Eric Kunzee5e26762020-10-13 16:11:07 -07005086class OutputShaper:
5087 # Methods in this class compute the expected output shape and datatype
5088 # for common classes of operations
5089 def __init__(self):
5090 pass
5091
5092 # These methods return arguments that can be used for
5093 # creating a new output tensor
5094 @staticmethod
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005095 def binaryBroadcastOp(ser, rng, a, b, error_name=None):
5096 if error_name != ErrorIf.RankMismatch:
5097 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005098 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005099
5100 shape = []
5101 for i in range(len(a.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005102 if a.shape[i] == 1 and error_name is None:
Eric Kunzee5e26762020-10-13 16:11:07 -07005103 shape.append(b.shape[i])
5104 else:
5105 shape.append(a.shape[i])
5106
Jerry Ge135c9552023-05-23 20:59:32 +00005107 fuzz_idx = rng.integers(0, len(a.shape))
5108 if error_name == ErrorIf.DimensionMismatch:
5109 shape[fuzz_idx] += 1
5110
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005111 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005112 all_dtypes = [
5113 DType.INT8,
5114 DType.INT16,
5115 DType.INT32,
5116 DType.INT48,
James Ward24dbc422022-10-19 12:20:31 +01005117 DType.FP16,
5118 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005119 DType.FP32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005120 ]
Matthew Haddoneacff9a2021-09-24 14:42:13 +01005121 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5122 outputDType = rng.choice(wrong_dtypes)
5123 else:
5124 outputDType = a.dtype
5125
5126 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005127
5128 @staticmethod
5129 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005130 assert len(a.shape) == len(b.shape)
5131 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005132
5133 shape = []
5134 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005135 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005136 shape.append(a.shape[i])
5137
Kevin Cheng550ccc52021-03-03 11:21:43 -08005138 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005139
5140 @staticmethod
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005141 def unaryOp(ser, rng, a, error_name=None):
5142 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005143 all_dtypes = [
5144 DType.INT8,
5145 DType.INT16,
5146 DType.INT32,
5147 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005148 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005149 DType.FP16,
5150 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005151 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +01005152 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5153 outputDType = rng.choice(wrong_dtypes)
5154 else:
5155 outputDType = a.dtype
5156
5157 return ser.addOutput(a.shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005158
5159 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005160 def selectOp(ser, rng, cond, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005161 if error_name != ErrorIf.RankMismatch:
5162 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005163 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005164
5165 shape = []
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005166 for i in range(len(cond.shape)):
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005167 if cond.shape[i] == 1 and error_name is None:
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005168 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
5169 else:
5170 shape.append(cond.shape[i])
Eric Kunzee5e26762020-10-13 16:11:07 -07005171
Jerry Ge135c9552023-05-23 20:59:32 +00005172 fuzz_idx = rng.integers(0, len(a.shape))
5173 if error_name == ErrorIf.DimensionMismatch:
5174 shape[fuzz_idx] += 1
5175
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005176 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005177 all_dtypes = [
5178 DType.INT8,
5179 DType.INT16,
5180 DType.INT32,
5181 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005182 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005183 DType.FP16,
5184 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005185 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005186 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5187 outputDType = rng.choice(wrong_dtypes)
5188 else:
5189 outputDType = a.dtype
5190
5191 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005192
5193 @staticmethod
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005194 def binaryComparisonOp(ser, rng, a, b, error_name=None):
Jeremy Johnson7e9ac9a2021-11-08 18:10:51 +00005195 if error_name != ErrorIf.RankMismatch:
5196 assert len(a.shape) == len(b.shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -08005197 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005198
5199 # Do broadcast
5200 shape = []
5201 for i in range(len(a.shape)):
Eric Kunzea1d49852022-01-04 10:07:29 -08005202 if a.shape[i] == 1 and len(b.shape) > i:
Eric Kunzee5e26762020-10-13 16:11:07 -07005203 shape.append(b.shape[i])
5204 else:
5205 shape.append(a.shape[i])
5206
Jerry Ge135c9552023-05-23 20:59:32 +00005207 fuzz_idx = rng.integers(0, len(a.shape))
5208 if error_name == ErrorIf.DimensionMismatch:
5209 shape[fuzz_idx] += 1
5210
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005211 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005212 wrong_dtypes = [
5213 DType.INT8,
5214 DType.INT16,
5215 DType.INT32,
5216 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005217 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005218 DType.FP16,
5219 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005220 ]
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005221 outputDType = rng.choice(wrong_dtypes)
5222 else:
5223 outputDType = DType.BOOL
5224
5225 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005226
5227 @staticmethod
Matthew Haddond6ce7252021-09-29 15:35:44 +01005228 def reduceOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005229 shape = a.shape.copy()
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005230 if error_name not in [
5231 ErrorIf.AxisSmallerZero,
5232 ErrorIf.AxisLargerRank,
5233 ErrorIf.ShapeOfAxisNotOne,
5234 ]:
Matthew Haddond6ce7252021-09-29 15:35:44 +01005235 shape[axis] = 1
5236 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
5237 shape[axis] = rng.integers(2, 10)
Eric Kunzee5e26762020-10-13 16:11:07 -07005238
Matthew Haddond6ce7252021-09-29 15:35:44 +01005239 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005240 all_dtypes = [
5241 DType.INT8,
5242 DType.INT16,
5243 DType.INT32,
5244 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005245 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005246 DType.FP16,
5247 DType.BF16,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005248 ]
Matthew Haddond6ce7252021-09-29 15:35:44 +01005249 wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
5250 outputDType = rng.choice(wrong_dtypes)
5251 else:
5252 outputDType = a.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005253
Matthew Haddond6ce7252021-09-29 15:35:44 +01005254 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005255
5256 @staticmethod
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005257 def argmaxOp(ser, rng, a, axis, error_name=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07005258 shape = a.shape.copy()
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005259
5260 if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
5261 del shape[axis]
5262
5263 if error_name == ErrorIf.ArgmaxOutputRankMismatch:
5264 remove = rng.choice([True, False])
5265 if remove and len(shape) > 1:
5266 del shape[0]
5267 else:
5268 shape.append(1)
5269 elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
5270 for i in range(len(shape)):
5271 shape[i] = shape[i] + rng.integers(1, 10)
5272
5273 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005274 all_dtypes = [
5275 DType.INT8,
5276 DType.INT16,
5277 DType.INT32,
5278 DType.INT48,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01005279 DType.FP32,
James Ward24dbc422022-10-19 12:20:31 +01005280 DType.FP16,
5281 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +00005282 DType.FP8E4M3,
5283 DType.FP8E5M2,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005284 ]
Matthew Haddonc4cf0372021-10-11 09:38:10 +01005285 wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
5286 outputDType = rng.choice(wrong_dtypes)
5287 else:
5288 outputDType = DType.INT32
5289
5290 return ser.addOutput(shape, outputDType)
Eric Kunzee5e26762020-10-13 16:11:07 -07005291
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a CONV2D operator to the serializer.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC.

    Args:
        ser: serializer whose ``addOutput`` creates the output tensor.
        rng: random generator; only used for ERROR_IF perturbations.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accumulator dtype, normally used as the output dtype.
        strides: per-axis strides, index 0 for H and 1 for W.
        padding: padding[0:2] pads H (before, after), padding[2:4] pads W.
        dilations: per-axis dilations, index 0 for H and 1 for W.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """
    # Spatial axis i (0 = H, 1 = W):
    #   (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> enlarge H, 2 -> enlarge W, 3 -> enlarge both. The amount is a
        # multiple of the stride so the non-integer output case is not hit.
        which = rng.choice(steps)
        for i in range(2):
            if which in (i + 1, 3):
                spatial[i] += rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # When the input dtype is the error, use a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            # Both FP16 and FP32 are kept out of the "wrong" candidates here.
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005345
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a CONV3D operator to the serializer.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC.

    Args:
        ser: serializer whose ``addOutput`` creates the output tensor.
        rng: random generator; only used for ERROR_IF perturbations.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accumulator dtype, normally used as the output dtype.
        strides: per-axis strides, indices 0/1/2 for D/H/W.
        padding: padding[2*i], padding[2*i+1] pad spatial axis i (before, after).
        dilations: per-axis dilations, indices 0/1/2 for D/H/W.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """
    # Spatial axis i (0 = D, 1 = H, 2 = W):
    #   (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) // stride + 1
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[1 + i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        # 1/2/3 -> enlarge only D/H/W respectively, 4 -> enlarge all three.
        # Growth is a multiple of the stride so the non-integer case is not hit.
        which = rng.choice(steps)
        for i in range(3):
            if which in (i + 1, 4):
                spatial[i] += rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[0]]

    # When the input dtype is the error, use a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5407
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a DEPTHWISE_CONV2D operator to the serializer.

    Layouts: IFM is NHWC, filter is HWCM, OFM is NHW(C*M).

    Args:
        ser: serializer whose ``addOutput`` creates the output tensor.
        rng: random generator; only used for ERROR_IF perturbations.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor; dims 0/1 are the kernel H/W, dims 2/3 give
            the output channel count as their product.
        accum_dtype: accumulator dtype, normally used as the output dtype.
        strides: per-axis strides, index 0 for H and 1 for W.
        padding: padding[0:2] pads H (before, after), padding[2:4] pads W.
        dilations: per-axis dilations, index 0 for H and 1 for W.
        error_name: optional ``ErrorIf`` case to provoke.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, out_dtype)`` returns.
    """
    # Spatial axis i (0 = H, 1 = W); kernel sizes sit at filter dims 0 and 1.
    spatial = [
        (
            ifm.shape[1 + i]
            - 1
            + padding[2 * i]
            + padding[2 * i + 1]
            - (filter.shape[i] - 1) * dilations[i]
        )
        // strides[i]
        + 1
        for i in range(2)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> enlarge H, 2 -> enlarge W, 3 -> enlarge both, in stride
        # multiples so the non-integer output case is not hit.
        which = rng.choice(steps)
        for i in range(2):
            if which in (i + 1, 3):
                spatial[i] += rng.choice(steps) * strides[i]

    ofm_shape = [ifm.shape[0], *spatial, filter.shape[2] * filter.shape[3]]

    # When the input dtype is the error, use a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005458
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Add the output tensor of a 2D pooling operator to the serializer.

    The input is NHWC. Spatial sizes are
    ``(in + pad_before + pad_after - kernel) // stride + 1``. The output keeps
    the input's batch, channel count and (normally) dtype.

    Returns:
        Whatever ``ser.addOutput(ofm_shape, outputDType)`` returns.
    """
    params_valid = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_valid:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Bad stride/padding means the test is invalid anyway; use 1x1.
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        # 1 -> enlarge H, 2 -> enlarge W, 3 -> enlarge both, in stride
        # multiples so the non-integer output case is not hit.
        which = rng.choice(steps)
        if which in (1, 3):
            out_h = out_h + rng.choice(steps) * stride[0]
        if which in (2, 3):
            out_w = out_w + rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_choice = rng.choice(list(set(candidates) - {ifm.dtype}))
    return ser.addOutput(ofm_shape, wrong_choice)
Eric Kunzee5e26762020-10-13 16:11:07 -07005498
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Add the output tensor of a FULLY_CONNECTED operator to the serializer.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC).
    ``rng`` and ``error_name`` are accepted for interface compatibility but
    are not read here.

    Returns:
        Whatever ``ser.addOutput`` returns for the (N, OC) tensor.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    # accum_dtype is validated in arg_gen (and deliberately invalidated for
    # ERROR_IF tests), so it is passed straight through.
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005511
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Add the output tensor of a MATMUL operator to the serializer.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W).

    For ``ErrorIf.WrongOutputType`` the output dtype is drawn from a fixed,
    per-input-dtype tuple of incorrect dtypes. For ``ErrorIf.WrongInputType``
    INT32 is used. Otherwise ``accum_dtype`` (validated in arg_gen) is used.

    Returns:
        Whatever ``ser.addOutput(output_shape, out_dtype)`` returns.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(output_shape, DType.INT32)
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, accum_dtype)  # Validated in arg_gen

    # Each tuple omits the dtype(s) that are never offered as "wrong" outputs
    # for that input dtype. Order matters: it feeds rng.choice directly.
    if a.dtype == DType.INT8:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    elif a.dtype == DType.INT16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    return ser.addOutput(output_shape, rng.choice(a=incorrect_types))
Eric Kunzee5e26762020-10-13 16:11:07 -07005577
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Add the output tensor of a CONCAT operator to the serializer.

    The output shape is the first input's shape with the ``axis`` dimension
    increased by every other input's size along ``axis``. When the ERROR_IF
    case makes that impossible (rank mismatch or invalid axis), the first
    input's shape is used unchanged.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    first = inputs[0]
    output_shape = first.shape.copy()

    summable = error_name not in (
        ErrorIf.ConcatInputRankMismatch,  # ranks differ
        ErrorIf.AxisLargerRank,  # axis out of range
        ErrorIf.AxisSmallerZero,  # axis out of range
    )
    if summable:
        for other in inputs[1:]:
            output_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name == ErrorIf.WrongOutputType:
        candidates = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidates - {first.dtype}))
    else:
        out_dtype = first.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005613
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Add the output tensor of a PAD operator to the serializer.

    ``padding[i]`` holds the (before, after) amounts for dimension ``i``.
    Each output dimension is ``before + after + a.shape[i]``.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    output_shape = a.shape.copy()
    for dim_idx, dim_size in enumerate(a.shape):
        before, after = padding[dim_idx][0], padding[dim_idx][1]
        output_shape[dim_idx] = before + after + dim_size

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Grow one randomly chosen dimension by 1 or 2.
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005644
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a DIM operator to the serializer.

    The output is always shape ``[1]``. Its dtype is ``DType.SHAPE`` unless
    ``ErrorIf.WrongOutputType`` is requested, in which case one of the
    ordinary tensor dtypes below is picked at random. ``a`` and ``axis`` are
    accepted for interface compatibility but are not read here.

    Returns:
        Whatever ``ser.addOutput([1], dtype)`` returns.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    # None of these equal DType.SHAPE, so all of them are wrong here.
    non_shape_dtypes = list(
        {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
    )
    return ser.addOutput([1], rng.choice(non_shape_dtypes))
5665
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Add the output tensor of a RESHAPE operator to the serializer.

    The output takes a copy of the requested ``shape``. For
    ``ErrorIf.TensorSizeInputOutputMismatch`` every dimension is grown by a
    random 1..9 so the element counts no longer match.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    return ser.addOutput(output_shape, rng.choice(list(set(candidates) - {a.dtype})))
Eric Kunzee5e26762020-10-13 16:11:07 -07005690
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Add the output tensor of a SLICE operator to the serializer.

    The output shape is a copy of ``size`` (``start`` is not read here).
    ERROR_IF cases may perturb the dtype or the shape.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    # The dtype is decided before the shape, matching the order of any
    # random draws.
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {input.dtype}))
    else:
        out_dtype = input.dtype

    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Small dims are only grown, so they can never drop to <= 0.
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005724
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Add the output tensor of a TILE operator to the serializer.

    Each output dimension is ``a.shape[i] * multiples[i]``. ``multiples``
    must have one entry per input dimension.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, dim in enumerate(a.shape):
        output_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005753
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Add the output tensor of a TRANSPOSE operator to the serializer.

    Output dimension ``i`` is ``a.shape[perms[i]]``. The permutation is not
    applied for the bad-index ERROR_IF cases, so the input shape is kept.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    bad_perm = error_name in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]
    if not bad_perm:
        for idx, src in enumerate(perms[: len(output_shape)]):
            output_shape[idx] = a.shape[src]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(output_shape)):
            output_shape[idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))
    else:
        out_dtype = a.dtype

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005786
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Add the output tensor of a GATHER operator to the serializer.

    ``values`` is rank 3 and ``indices`` is rank 2, sharing dimension 0.
    These checks are skipped for ``ErrorIf.WrongRank``. The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    n, w, c = values.shape[0], indices.shape[1], values.shape[2]
    output_shape = [n, w, c]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_choice = rng.choice(list(set(candidates) - {values.dtype}))
    return ser.addOutput(output_shape, wrong_choice)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005812
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Add the output tensor of a SCATTER operator to the serializer.

    ``values_in`` and ``input`` are rank 3 and ``indices`` is rank 2. The
    N, W and C relations are asserted unless ``ErrorIf.WrongRank`` is being
    tested. The output has the shape of ``values_in``.

    Note: ``values_in.shape`` itself (not a copy) is passed to
    ``ser.addOutput``.

    Returns:
        Whatever ``ser.addOutput(output_shape, outputDType)`` returns.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    output_shape = values_in.shape

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values_in.dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_choice = rng.choice(list(set(candidates) - {values_in.dtype}))
    return ser.addOutput(output_shape, wrong_choice)
Eric Kunzee5e26762020-10-13 16:11:07 -07005843
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Add the output tensor of a TABLE operator to the serializer.

    The output has the input's shape. Its dtype is INT32 for INT16 input and
    INT8 otherwise. For ``ErrorIf.WrongOutputType`` a different dtype is
    drawn at random.

    Returns:
        Whatever ``ser.addOutput(input.shape, output_dtype)`` returns.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype == DType.INT16 or input.dtype == DType.INT8

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        others = [
            dtype
            for dtype in [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
            ]
            if dtype != output_dtype
        ]
        output_dtype = rng.choice(others)

    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005863
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Add the output tensor of a RESIZE operator to the serializer.

    ``scale`` is read as ``[y_num, y_den, x_num, x_den]``; ``offset`` and
    ``border`` are read as ``[y, x]``. The output height and width are
    ``((in - 1) * num - offset + border) // den + 1``.

    ``mode`` and ``input_dtype`` are accepted for interface compatibility but
    are not read here. ``output_dtype`` is used as given.

    Returns:
        Whatever ``serializer.addOutput(output_dims, output_dtype)`` returns.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to >= 1 so the shape arithmetic below stays well defined.
        y_num, y_den, x_num, x_den = (
            max(value, 1) for value in (y_num, y_den, x_num, x_den)
        )

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Scale, offset or border may have been altered for an ERROR_IF test;
        # keep the output tensor itself valid.
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            cap = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, cap), min(out_w, cap)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def bump(size, step):
            # Move one denominator step away from the expected size (avoiding
            # the non-integer error case). Step down instead of up when going
            # up would reach MAX_RESIZE_DIMENSION.
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        choices = [1, 2, 3]
        change = rng.choice(choices)
        if change in (1, 3):
            out_h = bump(out_h, y_den)
        if change in (2, 3):
            out_w = bump(out_w, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # input.shape[3] may not exist for a wrong-rank input, so dim 0 is
        # reused for the last entry.
        output_dims = [batch, out_h, out_w, batch]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [batch + rng.integers(1, 10), out_h, out_w, input.shape[3]]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [batch, out_h, out_w, input.shape[3] + rng.integers(1, 10)]
    else:
        output_dims = [batch, out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005942
@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Add the output tensor of a type-conversion operator to the serializer.

    The output has the same shape as ``val`` and the requested ``out_dtype``.
    ``rng`` and ``error_name`` are accepted for interface compatibility but
    are not read here.

    Returns:
        Whatever ``ser.addOutput(val.shape, out_dtype)`` returns.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005946
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Add the output tensor of a TRANSPOSE_CONV2D operator to the serializer.

    ``output_shape`` is supplied by the caller. Indices 1 and 2 are treated
    as the spatial dims (presumably NHWC; confirm against the arg generator).

    Note: for ``ErrorIf.ConvOutputShapeMismatch`` the caller's
    ``output_shape`` object is modified in place (H and/or W grow by 1..3),
    so the caller observes the perturbed shape as well.

    Returns:
        Whatever ``ser.addOutput(output_shape, out_dtype)`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # 1 -> enlarge H, 2 -> enlarge W, 3 -> enlarge both.
        which = rng.choice(amounts)
        for dim in (1, 2):
            if which in (dim, 3):
                output_shape[dim] += rng.choice(amounts)

    # When the input dtype is the error, use a potentially correct output dtype.
    out_dtype = DType.INT32 if error_name == ErrorIf.WrongInputType else accum_dtype

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00005972
5973 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors for an FFT2D operator.

    Both outputs share a single shape and dtype. These are copied from
    ``ifm1`` and optionally perturbed for the requested ERROR_IF case.

    Args:
        serializer: serializer on which ``addOutput`` is called twice.
        rng: random generator used only for ERROR_IF perturbations.
        ifm1, ifm2: input tensors (presumably real/imaginary parts; confirm
            against the caller). They must share a dtype and, unless
            testing ``ErrorIf.FFTInputShapeMismatch``, a shape.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        A two-element list of the tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    dtype = ifm1.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Independent copy so perturbations below never touch ifm1's shape.
    out_shape = in_shape.copy()

    if error_name == ErrorIf.WrongOutputType:
        dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Both outputs are registered with the same shape object and dtype.
    return [serializer.addOutput(out_shape, dtype) for _ in range(2)]
6003
6004 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors for an RFFT2D operator.

    The output shape is ``value``'s shape with its last dimension replaced
    by ``last // 2 + 1``. Both outputs share that shape and ``value``'s dtype,
    optionally perturbed for the requested ERROR_IF case.

    Args:
        serializer: serializer on which ``addOutput`` is called twice.
        rng: random generator used only for ERROR_IF perturbations.
        value: input tensor; rank 3 unless testing ``ErrorIf.WrongRank``.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        A two-element list of the tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Keep every leading dimension; halve (rounding down) and add one to the last.
    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]

    dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        dtype = rng.choice(list(gtu.usableDTypes(excludes=[DType.FP32])))
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Both outputs are registered with the same shape object and dtype.
    return [serializer.addOutput(out_shape, dtype) for _ in range(2)]
6027 return outputs
Won Jeon74342e52024-01-09 00:34:40 +00006028
6029 @staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for the add-shape operator.

    The output copies ``a``'s shape and is typed ``DType.SHAPE``, unless an
    ERROR_IF case requests a perturbed dimension or a wrong output dtype.

    Args:
        ser: serializer on which ``addOutput`` is called.
        rng: random generator used for the ERROR_IF perturbations.
        a, b: input tensors. They must share a dtype and, unless testing
            ``ErrorIf.RankMismatch``, the same rank.
        error_name: optional ERROR_IF case to provoke.

    Returns:
        The result of ``ser.addOutput(shape, output_dtype)``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Independent copy so the perturbation below never touches a's shape.
    shape = list(a.shape)

    # Drawn unconditionally, even when unused. This consumes RNG state, so
    # moving it under the ``if`` would shift all subsequent random draws.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        output_dtype = DType.SHAPE
    return ser.addOutput(shape, output_dtype)