blob: 399fed69095429a6bcfccd627601782a13d5c6ba [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010023from generator.tosa_random_gen import TosaHashRandomGenerator
24from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010025from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000026from tosa.DType import DType
27from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010028
# Banner written at the top of each generated C++ meta data file
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# logging.basicConfig() gives the root logger a default handler if it has none
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
36
Matthew Haddonb724efc2021-08-25 16:40:29 +010037
Eric Kunzee5e26762020-10-13 16:11:07 -070038class TosaTestGen:
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000039 # This currently matches the 8K level defined in the specification.
Jeremy Johnsonb2099702023-04-12 15:59:01 +010040 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010041 TOSA_8K_LEVEL_MAX_KERNEL = 8192
42 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010043
Jeremy Johnson1271c442023-09-05 11:39:26 +010044 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000045 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010046 TOSA_MI_DOT_PRODUCT_MIN = 1000
47
Eric Kunzee5e26762020-10-13 16:11:07 -070048 def __init__(self, args):
49 self.args = args
50 self.basePath = args.output_dir
51 self.random_seed = args.random_seed
52 self.ser = None
Eric Kunzee5e26762020-10-13 16:11:07 -070053 self.createDynamicOpLists()
54 self.initOpListDefaults()
55 self.quantGen = TosaQuantGen()
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010056 self.global_rng = None
Eric Kunzee5e26762020-10-13 16:11:07 -070057 # Force makeShape to do a specific starting shape
58 self.targetted_shape = None
Jeremy Johnson1271c442023-09-05 11:39:26 +010059 # JSON schema validation
60 self.descSchemaValidator = TestDescSchemaValidator()
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +010061 # Data generator library is sometimes needed for compliance set up
62 # even if we are generating the data later (lazy_data_generation)
63 self.dgl = GenerateLibrary(args.generate_lib_path)
Eric Kunzee5e26762020-10-13 16:11:07 -070064
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010065 # Work out floating point range
66 def convertFPRange(rangeFP, maxFP):
67 # Converts program arguments of max/-max to FP max
68 vals = []
69 for v in rangeFP:
70 if v == "max":
71 v = maxFP
72 elif v == "-max":
73 v = -maxFP
Jeremy Johnsona8420ad2023-12-07 16:35:28 +000074 elif v < 0:
75 # Trim to minimum data type value
76 v = max(v, -maxFP)
77 elif v > 0:
78 # Trim to maximum data type value
79 v = min(v, maxFP)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010080 vals.append(v)
81 return tuple(sorted(vals))
82
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010083 self.random_dtype_range = {
84 DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
85 }
Won Jeon2c34b462024-02-06 18:37:00 +000086 for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010087 self.random_dtype_range[dtype] = convertFPRange(
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010088 args.tensor_fp_value_range,
89 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
90 )
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010091 self.resetGlobalRNG()
92
93 def resetGlobalRNG(self):
94 self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010095
Eric Kunzee5e26762020-10-13 16:11:07 -070096 def createSerializer(self, opName, testPath):
97 self.testPath = os.path.join(opName, testPath)
98
99 fullPath = os.path.join(self.basePath, self.testPath)
100 os.makedirs(fullPath, exist_ok=True)
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100101 # Embed const data in the flatbuffer
102 constMode = ts.ConstMode.EMBED
Jeremy Johnson1271c442023-09-05 11:39:26 +0100103 if self.args.lazy_data_gen:
104 # Lazy data generation - so make constants files
105 constMode = ts.ConstMode.INPUTS
106 elif self.args.dump_consts:
Jeremy Johnson55363fd2023-07-25 14:10:50 +0100107 constMode = ts.ConstMode.EMBED_DUMP
108 self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700109
110 def getSerializer(self):
111 return self.ser
112
evacha01ad8e1e22024-03-19 12:42:17 +0000113 def serialize(self, testName, metaData=None, tags=None):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100114 path = Path(self.basePath) / self.testPath
115
116 # Write out TOSA flatbuffer binary
117 path_fb = path / f"{testName}.tosa"
118 with path_fb.open("wb") as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700119 fd.write(self.ser.serialize())
120
Jeremy Johnson1271c442023-09-05 11:39:26 +0100121 # Get JSON descriptor from serializer
122 desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
123
124 if metaData:
125 # Add extra meta data to desc.json
126 desc["meta"] = metaData
127
evacha01ad8e1e22024-03-19 12:42:17 +0000128 if tags:
129 desc["tag"] = tags
130
Jeremy Johnson1271c442023-09-05 11:39:26 +0100131 # Validate desc.json before we output it
132 self.descSchemaValidator.validate_config(desc)
133
134 if metaData:
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100135 if "data_gen" in metaData:
136 if self.args.lazy_data_gen:
137 # Output datagen meta data as CPP data
138 path_md = path / f"{testName}_meta_data_gen.cpp"
139 with path_md.open("w") as fd:
140 fd.write(TOSA_AUTOGENERATED_HEADER)
141 fd.write("// Test meta data for data generation setup\n\n")
142 fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
143 json.dump(metaData["data_gen"], fd)
144 fd.write(')";\n\n')
Jeremy Johnson1271c442023-09-05 11:39:26 +0100145 if "compliance" in metaData:
146 # Output datagen meta data as CPP data
147 path_md = path / f"{testName}_meta_compliance.cpp"
148 with path_md.open("w") as fd:
149 fd.write(TOSA_AUTOGENERATED_HEADER)
150 fd.write("// Test meta data for compliance validation\n\n")
151 fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
152 json.dump(metaData["compliance"], fd)
153 fd.write(')";\n\n')
154
155 # Write desc.json
156 path_desc = path / "desc.json"
157 with path_desc.open("w") as fd:
158 json.dump(desc, fd, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700159
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100160 def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700161 placeholders = []
162
Kevin Cheng989cb052021-04-28 16:29:44 -0700163 assert len(shape_list) == len(dtype_list)
164
Jeremy Johnson1271c442023-09-05 11:39:26 +0100165 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700166 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100167 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100168 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700169 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700170
171 return placeholders
172
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100173 def buildConstTensors(self, rng, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 consts = []
175
Kevin Cheng989cb052021-04-28 16:29:44 -0700176 assert len(shape_list) == len(dtype_list)
177
Jeremy Johnson1271c442023-09-05 11:39:26 +0100178 arr = None
Kevin Cheng989cb052021-04-28 16:29:44 -0700179 for idx, shape in enumerate(shape_list):
Jeremy Johnson1271c442023-09-05 11:39:26 +0100180 if not self.args.lazy_data_gen:
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100181 arr = rng.randTensor(shape, dtype_list[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -0700182 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700183
184 return consts
185
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100186 def makeShape(self, rng, rank):
Eric Kunzee5e26762020-10-13 16:11:07 -0700187 if self.targetted_shape:
188 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800189 return np.int32(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100190 rng.integers(
Kevin Cheng550ccc52021-03-03 11:21:43 -0800191 low=self.args.tensor_shape_range[0],
192 high=self.args.tensor_shape_range[1],
193 size=rank,
194 )
195 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700196
197 def setTargetShape(self, shape):
198 self.targetted_shape = shape
199
Eric Kunzee5e26762020-10-13 16:11:07 -0700200 def shapeStr(self, shape):
201
202 sStr = []
203 # Convert to strings
204 for i in shape:
205 sStr.append(str(i))
206
Kevin Cheng550ccc52021-03-03 11:21:43 -0800207 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700208
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100209 def typeStr(self, dtype):
210 if isinstance(dtype, list) or isinstance(dtype, tuple):
211 assert len(dtype) >= 2
212 strs = [self.typeStr(t) for t in dtype]
213 # Limit types to the first 2 as the 3rd is the accumulator
214 return "x".join(strs[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700215 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100216 if dtype in gtu.DTYPE_ATTRIBUTES:
217 return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Kevin Cheng989cb052021-04-28 16:29:44 -0700218 else:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100219 raise Exception(
220 "Unknown dtype, cannot convert to string: {}".format(dtype)
221 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700222
Luke Hutton57287132023-02-06 14:54:18 +0000223 def constrictBatchSize(self, shape):
224 # Limit the batch size unless an explicit target shape set
225 if self.args.max_batch_size and not self.args.target_shapes:
226 shape[0] = min(shape[0], self.args.max_batch_size)
227 return shape
228
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100229 def makeDimension(self, rng):
230 return rng.randInt(
James Ward30124a82023-02-02 14:56:33 +0000231 low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
232 )
233
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100234 def tensorComplianceMetaData(
235 self, op, inputType, argsDict, outputTensor, errorName
236 ):
Jeremy Johnson4f931302024-01-04 17:05:24 +0000237 # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
238 UNSUPPORTED_NON_FP32_INPUT_OPS = (
239 Op.MATMUL,
240 Op.CONV2D,
241 Op.FULLY_CONNECTED,
242 Op.DEPTHWISE_CONV2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +0000243 Op.TRANSPOSE_CONV2D,
evacha0147ab1762024-01-29 13:23:23 +0000244 Op.CONV3D,
Jeremy Johnson4f931302024-01-04 17:05:24 +0000245 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100246 if (
247 errorName
248 or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
Jeremy Johnson708da822023-11-15 16:25:45 +0000249 or (
250 not gtu.dtypeIsSupportedByCompliance(inputType)
251 and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
252 )
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100253 ):
254 # No compliance for error tests or unsupported types currently
Jeremy Johnson1271c442023-09-05 11:39:26 +0100255 return None
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100256
Jeremy Johnson1271c442023-09-05 11:39:26 +0100257 # Create compliance meta data for expected output tensor
Jeremy Johnsonbb0935f2023-09-14 16:43:48 +0100258 compliance_tens = {
259 "mode": None,
260 # Data type is needed for all FP runs, as refmodel precise mode produces FP64
261 "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
262 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100263 if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
264 mode = gtu.ComplianceMode.DOT_PRODUCT
265 compliance_tens["dot_product_info"] = {
266 "s": argsDict["s"],
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100267 "ks": int(argsDict["ksb"])
268 if "ksb" in argsDict
269 else int(argsDict["ks"]),
Jeremy Johnson1271c442023-09-05 11:39:26 +0100270 }
evacha019c96eef2024-02-07 11:21:55 +0000271 elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100272 mode = gtu.ComplianceMode.FP_SPECIAL
273 elif "compliance" in op and "ulp" in op["compliance"]:
274 mode = gtu.ComplianceMode.ULP
275 compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +0000276 elif "compliance" in op and "relative" in op["compliance"]:
277 mode = gtu.ComplianceMode.RELATIVE
278 compliance_tens["relative_info"] = {
279 "max": argsDict["max_abs_value"],
280 "scale": op["compliance"]["relative"],
281 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100282 elif op["op"] == Op.REDUCE_PRODUCT:
283 mode = gtu.ComplianceMode.REDUCE_PRODUCT
Jeremy Johnsonbd801962024-01-03 17:07:44 +0000284 compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
Jeremy Johnson534923d2023-12-04 11:11:06 +0000285 elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
Jeremy Johnson9a758382023-11-07 16:27:35 +0000286 mode = gtu.ComplianceMode.ABS_ERROR
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +0000287 if "compliance" in op and "abs_error_lower_bound" in op["compliance"]:
288 compliance_tens["abs_error_info"] = {
289 "lower_bound": op["compliance"]["abs_error_lower_bound"]
290 }
Jerry Ge51bd4f52024-02-20 11:21:19 -0800291 elif op["op"] in (Op.SIN, Op.COS):
292 mode = gtu.ComplianceMode.ABS_ERROR
293 if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
294 compliance_tens["abs_error_info"] = {
295 "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
296 }
Jeremy Johnson1271c442023-09-05 11:39:26 +0100297 else:
298 mode = gtu.ComplianceMode.EXACT
299 compliance_tens["mode"] = gtu.ComplianceMode(mode).name
300
301 return compliance_tens
302
303 # Build Op functions
304 # Create the output tensor (calling OutputShaper as needed)
305 # Do final tweaks to attributes (if necessary for errorIf)
306 # Add Op into graph
307 # Return resulting tensor information or BuildInfo
308
309 class BuildInfo:
310 """Enhanced build information containing result tensor and associated compliance dict."""
311
312 def __init__(self, resultTensor, complianceDict):
Jeremy Johnsonc8330812024-01-18 16:57:28 +0000313 if isinstance(resultTensor, list):
314 assert complianceDict is None or isinstance(complianceDict, list)
315 self.resultTensorList = resultTensor
316 self.complianceDictList = complianceDict
317 else:
318 self.resultTensorList = [resultTensor]
319 if complianceDict is None:
320 self.complianceDictList = None
321 else:
322 self.complianceDictList = [complianceDict]
323
324 def getComplianceInfo(self):
325 if self.complianceDictList is None:
326 return None
327 else:
328 tens_dict = {}
329 for tens, comp in zip(self.resultTensorList, self.complianceDictList):
330 if comp is not None:
331 tens_dict[tens.name] = comp
332
333 if tens_dict:
334 # Have some compliance data, so return the info
335 compliance = {
336 "version": "0.1",
337 "tensors": tens_dict,
338 }
339 else:
340 compliance = None
341 return compliance
Eric Kunzee5e26762020-10-13 16:11:07 -0700342
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000343 def build_unary(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100344 self,
345 rng,
346 op,
347 inputs,
348 args_dict,
349 validator_fcns=None,
350 error_name=None,
351 qinfo=None,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000352 ):
353 assert len(inputs) == 1
354 a = inputs[0]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100355 result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100356
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000357 assert not isinstance(op, int)
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100358
359 # Ensure new output type has correct qinfo
360 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000361 if result_tensor.dtype not in [DType.INT8, DType.UINT8]:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000362 qinfo = [
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100363 TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype),
364 TosaQuantGen.getZeroPoint(
365 rng, self.args.zeropoint, result_tensor.dtype
366 ),
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000367 ]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100368
369 # Invalidate Input/Output list for error if checks.
370 input_list = [a.name]
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000371 output_list = [result_tensor.name]
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100372 pCount, cCount = op["operands"]
373 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000374 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100375 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000376 )
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100377
Les Bell729b0352021-11-24 10:28:21 +0000378 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100379 self.ser,
380 validator_fcns,
381 error_name,
382 op=op,
383 input_dtype=a.dtype,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000384 output_dtype=result_tensor.dtype,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000385 qinfo=qinfo,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000386 result_tensors=[result_tensor],
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100387 input_list=input_list,
388 output_list=output_list,
389 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000390 ):
391 return None
Matthew Haddone4ecdb22021-09-28 11:38:21 +0100392
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000393 attr = None
394 if op["op"] == Op.NEGATE:
395 attr = ts.TosaSerializerAttribute()
396 attr.NegateAttribute(qinfo[0], qinfo[1])
397
398 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000399
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +0000400 compliance = self.tensorComplianceMetaData(
401 op, a.dtype, args_dict, result_tensor, error_name
402 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000403 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700404
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000405 def build_binary_broadcast(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100406 self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000407 ):
408 assert len(inputs) == 2
409 a, b = inputs
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100410 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100411
412 # Invalidate Input/Output list for error if checks.
413 input_list = [a.name, b.name]
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000414 output_list = [result_tensor.name]
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100415 pCount, cCount = op["operands"]
416 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000417 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100418 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000419 )
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100420
Les Bell729b0352021-11-24 10:28:21 +0000421 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100422 self.ser,
423 validator_fcns,
424 error_name,
425 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000426 input1=a,
427 input2=b,
428 input_dtype=a.dtype,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000429 output_dtype=result_tensor.dtype,
430 result_tensors=[result_tensor],
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100431 input_list=input_list,
432 output_list=output_list,
433 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000434 ):
435 return None
Matthew Haddoneacff9a2021-09-24 14:42:13 +0100436
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000437 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000438
Jeremy Johnson9a758382023-11-07 16:27:35 +0000439 compliance = self.tensorComplianceMetaData(
440 op, a.dtype, args_dict, result_tensor, error_name
441 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000442
443 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700444
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000445 def build_arithmetic_right_shift(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100446 self,
447 rng,
448 op,
449 inputs,
450 args_dict,
451 validator_fcns=None,
452 error_name=None,
453 qinfo=None,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000454 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +0000455 assert len(inputs) == 2
456 a, b = inputs
457 round = args_dict["round"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100458 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100459
460 # Invalidate Input/Output list for error if checks.
461 input_list = [a.name, b.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000462 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100463 pCount, cCount = op["operands"]
464 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000465 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100466 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000467 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100468
Les Bell729b0352021-11-24 10:28:21 +0000469 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100470 self.ser,
471 validator_fcns,
472 error_name,
473 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000474 input1=a,
475 input2=b,
476 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000477 output_dtype=result_tensor.dtype,
478 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100479 input_list=input_list,
480 output_list=output_list,
481 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000482 ):
483 return None
Kevin Chengaee1fac2020-11-11 13:54:06 -0800484
485 attr = ts.TosaSerializerAttribute()
486 attr.ArithmeticRightShiftAttribute(round)
487
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000488 self.ser.addOperator(op["op"], input_list, output_list, attr)
Jeremy Johnson587cc842024-02-08 11:45:44 +0000489
490 compliance = self.tensorComplianceMetaData(
491 op, a.dtype, args_dict, result_tensor, error_name
492 )
493
494 return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800495
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100496 def build_mul(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100497 self,
498 rng,
499 op,
500 inputs,
501 args_dict,
502 validator_fcns=None,
503 error_name=None,
504 qinfo=None,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100505 ):
Jeremy Johnson0a042992024-02-28 13:20:05 +0000506 # Note that mul is binary operator but it has a shift value tensor
507 assert len(inputs) == 3
508 a, b, s = inputs
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100509
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100510 result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700511
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100512 # Special for multiply: Force the result to INT32 for INT types
James Ward24dbc422022-10-19 12:20:31 +0100513 if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100514 result_tensor.setDtype(DType.INT32)
515
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100516 if error_name == ErrorIf.WrongOutputType:
517 all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100518 outputDType = rng.choice(all_dtypes)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100519 result_tensor.setDtype(outputDType)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100520
521 # Invalidate Input/Output list for error if checks.
Jeremy Johnson0a042992024-02-28 13:20:05 +0000522 input_list = [a.name, b.name, s.name]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100523 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100524 pCount, cCount = op["operands"]
525 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000526 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100527 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000528 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100529
Les Bell729b0352021-11-24 10:28:21 +0000530 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100531 self.ser,
532 validator_fcns,
533 error_name,
534 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000535 input1=a,
536 input2=b,
537 input_dtype=a.dtype,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100538 output_dtype=result_tensor.dtype,
539 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100540 input_list=input_list,
541 output_list=output_list,
542 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000543 ):
544 return None
Eric Kunzee5e26762020-10-13 16:11:07 -0700545
Jeremy Johnson0a042992024-02-28 13:20:05 +0000546 self.ser.addOperator(op["op"], input_list, output_list)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100547
548 compliance = self.tensorComplianceMetaData(
549 op, a.dtype, args_dict, result_tensor, error_name
550 )
551
552 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
Jeremy Johnson587cc842024-02-08 11:45:44 +0000554 def build_table(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100555 self,
556 rng,
557 op,
558 inputs,
559 args_dict,
560 validator_fcns=None,
561 error_name=None,
562 qinfo=None,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000563 ):
564 assert len(inputs) == 1
565 a = inputs[0]
566 table = args_dict["table"]
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100567 result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700568
Kevin Chengfe392ce2021-10-18 21:51:55 +0000569 attr = ts.TosaSerializerAttribute()
570 attr.TableAttribute(table)
571
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100572 # Invalidate Input/Output list for error if checks.
573 input_list = [a.name]
Jeremy Johnson587cc842024-02-08 11:45:44 +0000574 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100575 pCount, cCount = op["operands"]
576 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000577 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100578 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000579 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100580
Les Bell729b0352021-11-24 10:28:21 +0000581 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100582 self.ser,
583 validator_fcns,
584 error_name,
585 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000586 input_shape=a.shape,
587 input_dtype=a.dtype,
Jeremy Johnson587cc842024-02-08 11:45:44 +0000588 output_dtype=result_tensor.dtype,
589 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100590 input_list=input_list,
591 output_list=output_list,
592 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000593 ):
594 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100595
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000596 self.ser.addOperator(op["op"], input_list, output_list, attr)
Eric Kunzee5e26762020-10-13 16:11:07 -0700597
Jeremy Johnson587cc842024-02-08 11:45:44 +0000598 compliance = self.tensorComplianceMetaData(
599 op, a.dtype, args_dict, result_tensor, error_name
600 )
601
602 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700603
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000604 def build_select(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100605 self,
606 rng,
607 op,
608 inputs,
609 args_dict,
610 validator_fcns=None,
611 error_name=None,
612 qinfo=None,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000613 ):
614 assert len(inputs) == 3
615 cond, a, b = inputs
616
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100617 result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name)
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100618
619 # Invalidate Input/Output list for error if checks.
620 input_list = [cond.name, a.name, b.name]
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000621 output_list = [result_tensor.name]
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100622 pCount, cCount = op["operands"]
623 num_operands = pCount + cCount
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000624 input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +0100625 rng, error_name, input_list, output_list
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000626 )
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100627
Les Bell729b0352021-11-24 10:28:21 +0000628 if not TosaErrorValidator.evValidateErrorIfs(
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100629 self.ser,
630 validator_fcns,
631 error_name,
632 op=op,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000633 input1=cond,
634 input2=a,
635 input3=b,
636 input_shape=a.shape,
637 input_dtype=a.dtype,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000638 output_dtype=result_tensor.dtype,
639 result_tensors=[result_tensor],
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100640 input_list=input_list,
641 output_list=output_list,
642 num_operands=num_operands,
Les Bell729b0352021-11-24 10:28:21 +0000643 ):
644 return None
Matthew Haddonbb5676f2021-10-13 11:30:30 +0100645
Jeremy Johnson5c1364c2022-01-13 15:04:21 +0000646 self.ser.addOperator(
647 op["op"],
648 input_list,
649 output_list,
650 )
Jeremy Johnson7b9abce2024-01-10 11:07:29 +0000651 compliance = self.tensorComplianceMetaData(
652 op, a.dtype, args_dict, result_tensor, error_name
653 )
654
655 return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700656
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a two-input comparison operator and register it with the serializer.

    Args:
        rng: Random generator; consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly two tensors, used as the left and right operands.
        args_dict: Test arguments, forwarded to the compliance metadata helper.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(self.ser, rng, lhs, rhs, error_name)

    # For ERROR_IF tests the name lists may be deliberately corrupted.
    in_names = [lhs.name, rhs.name]
    out_names = [out_tensor.name]
    prim_count, const_count = op["operands"]
    operand_total = prim_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700709
def build_argmax(
    self, rng, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Build a single-input operator parameterised by an axis (AxisAttribute).

    Args:
        rng: Random generator; consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` selects the axis.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    reduce_axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, rng, src, reduce_axis, error_name)

    # For ERROR_IF tests the name lists may be deliberately corrupted.
    in_names = [src.name]
    out_names = [out_tensor.name]
    prim_count, const_count = op["operands"]
    operand_total = prim_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=reduce_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reduce_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700753
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a 2D pooling operator with a PoolAttribute.

    Args:
        rng: Random generator; consumed by the output shaper, by zero-point
            generation (WrongInputType case) and by the ERROR_IF list
            invalidation, in that order.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Reads ``"stride"``, ``"pad"``, ``"kernel"`` and optionally
            ``"acc_type"`` (absent for max_pool, which then uses DType.UNKNOWN).
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to PoolAttribute
            (presumably [input_zp, output_zp] -- TODO confirm). Treated as
            [0, 0] when None.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 1
    # Named ``ifm`` rather than ``input`` to avoid shadowing the builtin.
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType:
        if ifm.dtype not in [DType.INT8, DType.UINT8]:
            qinfo = [
                TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
                TosaQuantGen.getZeroPoint(
                    rng, self.args.zeropoint, result_tensor.dtype
                ),
            ]

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100831
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV2D-style operator with a ConvAttribute.

    Args:
        rng: Random generator; consumed by the output shaper, by zero-point
            generation (WrongInputType case) and by the ERROR_IF list
            invalidation, in that order.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Reads ``"acc_type"``, ``"stride"``, ``"pad"`` (4 entries)
            and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute
            (presumably [input_zp, weight_zp] -- TODO confirm). Treated as
            [0, 0] when None, matching build_pool2d; previously a None
            value crashed on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 3
    # ``weight`` rather than ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default the zero points only after validation, as build_pool2d does,
    # so validators still see the qinfo that was actually supplied.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700919
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONV3D-style operator with a ConvAttribute.

    Args:
        rng: Random generator; consumed by the output shaper, by zero-point
            generation (WrongInputType case) and by the ERROR_IF list
            invalidation, in that order.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Reads ``"acc_type"``, ``"stride"``, ``"pad"`` (6 entries)
            and ``"dilation"``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute
            (presumably [input_zp, weight_zp] -- TODO confirm). Treated as
            [0, 0] when None, matching build_pool2d; previously a None
            value crashed on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 3
    # ``weight`` rather than ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default the zero points only after validation, as build_pool2d does,
    # so validators still see the qinfo that was actually supplied.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001007
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE_CONV2D-style operator with a TransposeConvAttribute.

    Args:
        rng: Random generator; consumed by the output shaper, by zero-point
            generation (WrongInputType case) and by the ERROR_IF list
            invalidation, in that order.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Reads ``"acc_type"``, ``"stride"``, ``"pad"`` (4 entries,
            the output padding) and ``"out_shape"``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to TransposeConvAttribute
            (presumably [input_zp, weight_zp] -- TODO confirm). Treated as
            [0, 0] when None, matching build_pool2d; previously a None
            value crashed on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 3
    # ``weight`` rather than ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default the zero points only after validation, as build_pool2d does,
    # so validators still see the qinfo that was actually supplied.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001086
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DEPTHWISE_CONV2D-style operator with a ConvAttribute.

    Args:
        rng: Random generator; consumed by the output shaper, by zero-point
            generation (WrongInputType case) and by the ERROR_IF list
            invalidation, in that order.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input feature map, weights, bias.
        args_dict: Reads ``"acc_type"``, ``"stride"``, ``"pad"`` and
            ``"dilation"``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute
            (presumably [input_zp, weight_zp] -- TODO confirm). Treated as
            [0, 0] when None, matching build_pool2d; previously a None
            value crashed on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 3
    # ``weight`` rather than ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default the zero points only after validation, as build_pool2d does,
    # so validators still see the qinfo that was actually supplied.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001173
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a FULLY_CONNECTED-style operator with a FullyConnectedAttribute.

    Args:
        rng: Random generator; consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly three tensors: input, weights, bias.
        args_dict: Reads ``"acc_type"``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to FullyConnectedAttribute.
            Treated as [0, 0] when None, matching build_pool2d; previously
            a None value crashed on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 3
    # ``weight`` rather than ``filter`` to avoid shadowing the builtin.
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default the zero points only after validation, as build_pool2d does,
    # so validators still see the qinfo that was actually supplied.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001230
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a MATMUL-style operator with a MatMulAttribute.

    Args:
        rng: Random generator; consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly two tensors (a, b).
        args_dict: Reads ``"acc_type"``.
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Two-element zero-point list passed to MatMulAttribute
            (presumably [a_zp, b_zp] -- TODO confirm). Treated as [0, 0]
            when None, matching build_pool2d; previously a None value
            crashed on ``qinfo[0]``.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # Default the zero points only after validation, as build_pool2d does,
    # so validators still see the qinfo that was actually supplied.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001287
def build_reduce(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Build a single-input reduction operator along one axis (AxisAttribute).

    Args:
        rng: Random generator; consumed by the output shaper and by the
            ERROR_IF input/output list invalidation.
        op: Operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; ``args_dict["axis"]`` selects the axis.
            For a valid REDUCE_PRODUCT test, ``args_dict["n"]`` is written
            (the caller's dict is modified in place).
        validator_fcns: ERROR_IF validator functions to evaluate.
        error_name: ERROR_IF case being generated, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or None when the
        ERROR_IF validation rejects this configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    reduce_axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, rng, src, reduce_axis, error_name)

    # For ERROR_IF tests the name lists may be deliberately corrupted.
    in_names = [src.name]
    out_names = [out_tensor.name]
    prim_count, const_count = op["operands"]
    operand_total = prim_count + const_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=reduce_axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not is_valid:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reduce_axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Number of products - needed for compliance
        args_dict["n"] = src.shape[reduce_axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001336
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CLAMP operator whose bounds are drawn at random.

    Two values are drawn with ``rng.randNumberDType`` for the input dtype.
    Normally the smaller one becomes ``min_val`` and the larger ``max_val``.
    For ``ErrorIf.MaxSmallerMin``, two distinct values are drawn and the
    assignment is swapped so that max < min.

    The bounds are packed little-endian into the ClampAttribute:
    - float32 (``"<f"``) for BF16/FP16/FP32 inputs, with FP16 first widened to
      np.float32;
    - int32 (``"<i"``) for INT8/INT16 inputs;
    - int32 zeros for any other dtype.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    def draw_bounds():
        # The list literal evaluates left to right, keeping the rng call order.
        return [rng.randNumberDType(src.dtype), rng.randNumberDType(src.dtype)]

    bounds = draw_bounds()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot demonstrate the error, so redraw until distinct.
        while bounds[0] == bounds[1]:
            bounds = draw_bounds()
        max_val, min_val = min(bounds), max(bounds)
    else:
        max_val, min_val = max(bounds), min(bounds)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not accepted:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        pack_fmt, low, high = "<f", min_val, max_val
    elif src.dtype in (DType.INT8, DType.INT16):
        pack_fmt, low, high = "<i", min_val, max_val
    else:
        # Zero bounds avoid an internal error for incorrect input types.
        pack_fmt, low, high = "<i", 0, 0
    clamp_attr.ClampAttribute(
        self.ser.builder, struct.pack(pack_fmt, low), struct.pack(pack_fmt, high)
    )

    self.ser.addOperator(op["op"], input_list, output_list, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001416
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an attribute-less elementwise activation operator.

    The output tensor comes from ``OutputShaper.unaryOp``, so its shape follows
    the single input (subject to ERROR_IF perturbation).

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]

    result_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001464
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a CONCAT or CONCAT_SHAPE operator joining all ``inputs``.

    CONCAT reads its axis from ``args_dict["axis"]`` and serializes it as an
    AxisAttribute. CONCAT_SHAPE always uses axis 0 and is serialized without
    an attribute.

    Args:
        rng: random generator handed to the output shaper and ERROR_IF helpers.
        op: operator description; ``op["op"]`` and ``op["operands"]`` are read.
        inputs: the tensors to concatenate; ``inputs[0]`` supplies the
            input shape/dtype used for validation and compliance.
        args_dict: test arguments (``"axis"`` for CONCAT).
        validator_fcns: ERROR_IF validators forwarded to evValidateErrorIfs.
        error_name: the ERROR_IF case being generated, or None.
        qinfo: not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        # Exact-type check: rejects bool and other int subclasses. Written with
        # ``is`` rather than ``==`` for type comparison (pycodestyle E721).
        assert type(axis) is int, f"concat axis must be an int, got {type(axis)!r}"

    result_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        # CONCAT_SHAPE carries no attribute.
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001530
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a PAD operator taking the data tensor and a padding tensor.

    The pad constant is serialized as little-endian bytes. It is packed as a
    float32 from ``args_dict["pad_const_fp"]`` when ``gtu.dtypeIsFloat``
    reports the input dtype as floating point, and as an int32 from
    ``args_dict["pad_const_int"]`` otherwise.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 2
    src, padding_tensor = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    result_tensor = OutputShaper.padOp(self.ser, rng, src, padding, error_name)

    if gtu.dtypeIsFloat(src.dtype):
        pad_bytes = struct.pack("<f", const_fp)
    else:
        pad_bytes = struct.pack("<i", const_int)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_bytes)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, padding_tensor.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=src,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001594
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a DIM operator that queries ``args_dict["axis"]`` of one input.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None). No compliance metadata is
        produced here. Returns None when evValidateErrorIfs rejects the
        configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    result_tensor = OutputShaper.dimOp(self.ser, rng, src, dim_axis, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001641
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a RESHAPE operator taking the data tensor and a shape tensor.

    The output tensor is sized from ``args_dict["new_shape"]``. The operator
    itself is serialized without an attribute; the target shape is passed as
    the second input.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 2
    src, shape_tensor = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    result_tensor = OutputShaper.reshapeOp(self.ser, rng, src, new_shape, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, shape_tensor.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001690
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a REVERSE operator along ``args_dict["axis"]``.

    The output tensor comes from ``OutputShaper.unaryOp``.

    Returns:
        TosaTestGen.BuildInfo(result tensor, None). No compliance metadata is
        produced here. Returns None when evValidateErrorIfs rejects the
        configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    reverse_axis = args_dict["axis"]
    result_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=reverse_axis,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(reverse_axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)
    return TosaTestGen.BuildInfo(result_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001737
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TRANSPOSE operator using the permutation ``args_dict["perms"]``.

    The permutation is serialized as a TransposeAttribute and also passed to
    the ERROR_IF validators.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 1
    src = inputs[0]
    permutation = args_dict["perms"]

    result_tensor = OutputShaper.transposeOp(
        self.ser, rng, src, permutation, error_name
    )

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(permutation)

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        perms=permutation,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=src,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001791
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a SLICE operator taking data, start and size tensors.

    The output tensor is sized from the constant ``args_dict["start"]`` and
    ``args_dict["size"]`` values. The operator is serialized without an
    attribute; start and size are passed as the second and third inputs.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 3
    src, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    result_tensor = OutputShaper.sliceOp(
        self.ser, rng, src, start_values, size_values, error_name
    )

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [src.name, start_tensor.name, size_tensor.name],
        [result_tensor.name],
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        start=start_values,
        size=size_values,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=src,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001846
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a TILE operator taking the data tensor and a multiples tensor.

    The output tensor is sized from ``args_dict["multiples"]``. The operator
    is serialized without an attribute; the multiples are passed as the
    second input.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 2
    src, multiples_tensor = inputs[0], inputs[1]
    multiples_values = args_dict["multiples"]
    result_tensor = OutputShaper.tileOp(
        self.ser, rng, src, multiples_values, error_name
    )

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, multiples_tensor.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=result_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        input1=src,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001898
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a GATHER operator taking a values tensor and an indices tensor.

    Validation and compliance metadata use the values tensor's shape and
    dtype as the "input" description.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None when
        evValidateErrorIfs rejects the configuration.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    result_tensor = OutputShaper.gatherOp(
        self.ser, rng, values_tensor, indices_tensor, error_name
    )

    # For ERROR_IF tests the operand name lists may be deliberately corrupted.
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_tensor.name, indices_tensor.name],
        [result_tensor.name],
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_tensor.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_tensor.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001948
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test.

    Args:
        rng: Random generator passed to the output shaper and to the
            ERROR_IF input/output list invalidation.
        op: Operator entry; its "op" and "operands" keys are read.
        inputs: Exactly three tensors, in order values_in, indices, input.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    values_in, indices, input_tensor = inputs

    result_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, input_tensor, error_name
    )

    # Invalidate Input/Output list for error if checks.
    params_count, consts_count = op["operands"]
    operand_total = params_count + consts_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [result_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        input_dtype=values_in.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001997
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test whose scale/offset/border are input tensors.

    Args:
        rng: Random generator passed to the output shaper and to the
            ERROR_IF input/output list invalidation.
        op: Operator entry; its "op" and "operands" keys are read.
        inputs: Exactly four tensors: input, scale, offset and border.
        args_dict: Must provide "mode", "scale", "offset", "border" and
            "output_dtype".
        validator_fcns: ERROR_IF validator functions to run. Defaults to None,
            consistent with the other build_* methods (previously required).
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 4
    # Named input_tensor to avoid shadowing the builtin input()
    input_tensor = inputs[0]
    scale_input = inputs[1]
    offset_input = inputs[2]
    border_input = inputs[3]

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tensor.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # scale/offset/border are passed as operator inputs (input_list above), so
    # the corresponding ResizeAttribute fields are written empty
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002077
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST test from its single input tensor.

    The tensor is registered through self.ser.addOutputTensor and also serves
    as the result for the compliance metadata. rng, validator_fcns and qinfo
    are not used by this builder.

    Returns:
        TosaTestGen.BuildInfo for the constant tensor.
    """
    assert len(inputs) == 1
    const_tensor = inputs[0]

    self.ser.addOutputTensor(const_tensor)

    return TosaTestGen.BuildInfo(
        const_tensor,
        self.tensorComplianceMetaData(
            op, const_tensor.dtype, args_dict, const_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002097
2098 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST operator test.

    Args:
        rng: Random generator passed to the output shaper and to the
            ERROR_IF input/output list invalidation.
        op: Operator entry; its "op" and "operands" keys are read.
        inputs: Exactly one tensor, the value to convert.
        args_dict: Must provide "out_type", the destination DType.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    source = inputs[0]
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, source, target_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    params_count, consts_count = op["operands"]
    operand_total = params_count + consts_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [source.name], [result_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=operand_total,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, source.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002149
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    Zero points are chosen randomly here based on the input/output dtypes (and
    on the ERROR_IF under test). For valid scale32 tests, the stored input data
    is clamped so that (value - input_zp) lies within the apply_scale_32 range
    implied by each channel's shift.

    Args:
        rng: Random generator used for the output shaper, zero points and
            ERROR_IF input/output list invalidation.
        op: Operator entry; its "op" and "operands" keys are read.
        inputs: Exactly three tensors: value, multiplier and shift.
        args_dict: Must provide "output_dtype", "scale", "double_round",
            "per_channel", "shift" and "multiplier".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Ignored; it is rebuilt from the chosen zero points.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # Number of channels with their own shift value
    nc = val.shape[-1] if per_channel else 1

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = rng.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = rng.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = rng.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = rng.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = rng.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = rng.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    # Per-channel bounds [-(1 << (shift - 1)), (1 << (shift - 1)) - 1]
    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    for i in range(nc):
        min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
        max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1

    logger.debug(
        f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        values_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(values_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check every element fits the stored dtype before the unsafe cast.
        # (ndarray.all() yields a single bool, so comparing it against the
        # dtype limits would not check the data at all.)
        assert np.all(val_adj >= np.iinfo(values.dtype).min) and np.all(
            val_adj <= np.iinfo(values.dtype).max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(values_path, val_adj, False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002316
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Create the constant condition tensor for a COND_IF test.

    Args:
        rng: Random generator used for the ERROR_IF type/shape variants.
        op: Operator entry, passed to gtu.get_wrong_output_type.
        cond: The condition value stored in the constant.
        error_name: ERROR_IF condition under test, or None.

    Returns:
        The constant tensor created by self.ser.addConst. It is BOOL with a
        rank-0 shape unless a condition-related ERROR_IF is being tested.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2333
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF whose then/else blocks each output an INT32 constant.

    Only the shape of the supplied then tensor is used; the two input tensors'
    data is ignored. Each block holds a random constant of that shape. For the
    output-list-mismatch ERROR_IF tests, one block instead gets a constant with
    a perturbed shape.

    Args:
        rng: Random generator for the condition tensor, shapes and data.
        op: Operator entry; its "op" key is read.
        inputs: Exactly two tensors (then, else).
        args_dict: Must provide "condition".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    then_tens, else_tens = inputs
    cond = args_dict["condition"]

    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    if error_name in (then_mismatch, else_mismatch):
        # Shift every dimension; small dimensions only grow so they stay positive
        incorrect_shape = [
            dim + (rng.choice([-3, -2, 2, 3]) if dim > 3 else rng.choice([1, 2, 4]))
            for dim in then_tens.shape
        ]
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Fill each block with a single constant feeding its output
    for block_name, mismatch_error, block_arr in (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    ):
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_const = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_const = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_const)

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not checks_passed:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002419
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF whose then/else blocks apply a binary op to a and b.

    FP32/BF16/FP16/INT32 use add (then) / sub (else). INT8/INT16 use
    logical_right_shift (then) / logical_left_shift (else). Compliance is
    computed for the op of the branch selected by the condition.

    Args:
        rng: Random generator for the condition tensor and ERROR_IF shapes.
        op: Operator entry; its "op" key is read.
        inputs: Exactly two tensors, a and b.
        args_dict: Must provide "condition".
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        # Perturb every dimension so the block's tensor shape differs from a's.
        # Dimensions of 3 or less are only grown: subtracting 2 or 3 from them
        # would produce a zero or negative (invalid) dimension.
        incorrect_shape = [
            dim + (rng.choice([-3, -2, 2, 3]) if dim > 3 else rng.choice([1, 2, 4]))
            for dim in a.shape
        ]
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        assert False, f"No tests for DType: {a.dtype}"

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        if (
            error_name == ErrorIf.CondIfInputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfInputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(incorrect_block_input)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        elif (
            error_name == ErrorIf.CondIfOutputListThenGraphMismatch
            and block == then_block
        ) or (
            error_name == ErrorIf.CondIfOutputListElseGraphMismatch
            and block == else_block
        ):
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
        else:
            self.ser.addInputTensor(a)
            self.ser.addInputTensor(b)
            tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2524
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test that counts an iteration tensor down to zero.

    The top-level graph feeds (iteration count, a, accumulator) into the
    WHILE_LOOP operator. The COND block computes ``iter > 0`` and the BODY
    block computes ``acc = a + acc`` and ``iter = iter - 1``. For ERROR_IF
    tests the shapes/types of selected tensors are perturbed so that the
    requested error condition is triggered.

    Args:
        rng: Random generator, used only to perturb tensors for ERROR_IF tests.
        op: Operator entry from TOSA_OP_LIST.
        inputs: Single-element list holding the ``a`` input tensor.
        args_dict: Test arguments; ``args_dict["iterations"]`` is the initial
            iteration count.
        validator_fcns: ERROR_IF validator functions for the operator.
        error_name: ERROR_IF condition under test, or None for a positive test.
        qinfo: Unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the accumulator output, or None when the
        ERROR_IF validation rejects the generated test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    # Rank-0 loop counter; named iter_tens so the builtin iter() is not shadowed
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    def shape_mismatched(tens):
        # Deep copy of tens with every dimension offset by a random non-zero
        # amount; used to build deliberately mismatching ERROR_IF tensors.
        # One rng.choice call per dimension, in dimension order.
        bad = deepcopy(tens)
        for i in range(len(bad.shape)):
            bad.shape[i] += rng.choice([-3, -2, 2, 3])
        return bad

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        incorrect_acc = shape_mismatched(acc)
        acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
    else:
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in [
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ]:
        incorrect_iter = shape_mismatched(iter_tens)
        if len(incorrect_iter.shape) == 0:
            # The counter is rank 0, so add a dimension to force a mismatch
            incorrect_iter.shape.append(rng.choice([-3, -2, 2, 3]))
        incorrect_acc = shape_mismatched(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = [incorrect_iter, a, incorrect_acc]
    else:
        cond_inputs = [iter_tens, a, acc]
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, a + acc)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = [incorrect_iter, a, incorrect_acc]
    else:
        body_inputs = [iter_tens, a, acc]
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_body_out = self.ser.addIntermediate(
            incorrect_iter.shape, incorrect_iter.dtype
        )
        acc_body_out = self.ser.addIntermediate(
            incorrect_acc.shape, incorrect_acc.dtype
        )
    else:
        iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    self.ser.addOutputTensor(iter_body_out)
    self.ser.addOutputTensor(a)
    self.ser.addOutputTensor(acc_body_out)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )

    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002662
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a two-input FFT2D operator test.

    The result tensors come from OutputShaper.fft2dOp. For ERROR_IF tests
    the input/output name lists may be invalidated before validation.

    Args:
        rng: Random generator passed to the output shaper and list invalidator.
        op: Operator entry from TOSA_OP_LIST.
        inputs: The two input tensors.
        args_dict: Test arguments; ``args_dict["inverse"]`` selects the
            inverse transform.
        validator_fcns: ERROR_IF validator functions for the operator.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Unused by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the list of result tensors and one
        compliance entry per result, or None when ERROR_IF validation
        rejects the test.
    """
    assert len(inputs) == 2
    tens_a, tens_b = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, tens_a, tens_b, error_name)

    placeholders, consts = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [tens_a.name, tens_b.name],
        [res.name for res in results],
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=tens_a,
        input2=tens_b,
        input_shape=tens_a.shape,
        input_dtype=tens_a.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, tens_a.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002727
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a single-input RFFT2D operator test.

    The result tensors come from OutputShaper.rfft2dOp. For ERROR_IF tests
    the input/output name lists may be invalidated before validation.

    Args:
        rng: Random generator passed to the output shaper and list invalidator.
        op: Operator entry from TOSA_OP_LIST.
        inputs: Single-element list holding the input tensor.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions for the operator.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Unused by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the list of result tensors and one
        compliance entry per result, or None when ERROR_IF validation
        rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, rng, src, error_name)

    placeholders, consts = op["operands"]
    result_shapes = [res.shape for res in results]
    result_dtypes = [res.dtype for res in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [res.name for res in results]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result_shapes,
        output_dtype=result_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholders + consts,
    )
    if not valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False
    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, src.dtype, args_dict, res, error_name)
        for res in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002785
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a two-input shape operator test.

    The result tensor is created by OutputShaper.addShapeOp.

    Args:
        rng: Random generator used by the output shaper and to invalidate
            the input/output lists for ERROR_IF tests.
        op: Operator entry from TOSA_OP_LIST.
        inputs: The two input shape tensors.
        args_dict: Test arguments, forwarded to the compliance meta data.
        validator_fcns: ERROR_IF validator functions for the operator.
        error_name: ERROR_IF condition under test, or None.
        qinfo: Unused by this builder.

    Returns:
        TosaTestGen.BuildInfo for the result tensor, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # eiInvalidateInputOutputList takes the random generator as its first
    # argument (previously `self` was passed here by mistake).
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2838
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters for generating ``op`` tests.

    Args:
        op: Operator entry; its "rank" (min, max) and "types" fields are read.
        shapeFilter: Requested shapes; falsy means "no shape requested".
        rankFilter: Requested ranks, or None.
        dtypeFilter: Requested dtypes, or None.
        testType: "positive" or "negative".
        validator: ERROR_IF validator, required for negative tests.

    Returns:
        Dict with "shapeFilter", "rankFilter" and "dtypeFilter" keys, or None
        for a negative test without a validator or an unknown testType.
    """
    # Keep default test sizes reasonably small: ranks 1-4 inclusive.
    default_ranks = range(1, 5)
    shapes = shapeFilter if shapeFilter else [None]

    low_rank, high_rank = op["rank"]
    if rankFilter is not None:
        # Only keep the requested ranks that the operator supports
        ranks = [rank for rank in rankFilter if low_rank <= rank <= high_rank]
    elif shapes[0] is None:
        # Nothing requested: use the smaller of the operator's rank range and
        # the default range.
        op_ranks = range(low_rank, high_rank + 1)
        ranks = op_ranks if len(op_ranks) <= len(default_ranks) else default_ranks
    else:
        ranks = range(low_rank, high_rank + 1)

    op_dtypes = op["types"]
    if dtypeFilter is None:
        dtypes = op_dtypes
    else:
        # Keep operator dtypes that were requested directly, or dtype lists
        # whose first entry was requested.
        dtypes = [
            dtype
            for dtype in op_dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {"shapeFilter": shapes, "rankFilter": ranks, "dtypeFilter": dtypes}

    if testType != "negative" or validator is None:
        return None

    param_reqs = validator(check=False, op=op)["param_reqs"]

    # The validator's own requirements override the computed filters.
    required_rank = param_reqs["rank"]
    neg_ranks = ranks if required_rank is None else required_rank

    required_dtype = param_reqs["dtype"]
    neg_dtypes = dtypes if required_dtype is None else required_dtype

    required_shape = param_reqs["shape"]
    # Without explicit shapes, only keep two to keep test numbers small
    neg_shapes = shapes[:2] if required_shape is None else required_shape

    return {
        "shapeFilter": neg_shapes,
        "rankFilter": neg_ranks,
        "dtypeFilter": neg_dtypes,
    }
2919
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests to create for operator ``opName``.

    Args:
        opName: Key into TOSA_OP_LIST.
        shapeFilter: Optional list of shapes to generate tests for; None (or
            an empty list) lets the tensor generator choose shapes.
        rankFilter: Optional list of ranks to restrict tests to.
        dtypeFilter: Optional list of dtypes to restrict tests to.
        testType: "positive" for normal tests, "negative" for ERROR_IF tests.

    Returns:
        List of (opName, testStr, dtype, error_name, shapeList, args) tuples.

    Raises:
        Exception: If ``opName`` is unknown, or (outside level8k mode) if no
            test could be created for one or more ERROR_IF validators.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testStr, dtype, error_name, shapeList, args)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
        num_error_types_created = 0
    else:
        error_if_validators = [None]
        num_error_types_created = None

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument lists consist of (str, args) tuples: the string
                    # representation and the build function arguments
                    if agen_fcn:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng

                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )
        if error_name is not None:
            # Check the last test is of the error we wanted
            if len(testList) == 0 or testList[-1][3] != error_name:
                if self.args.level8k:
                    logger.info(f"Missing {error_name} tests due to level8k mode")
                else:
                    logger.error(f"ERROR: Failed to create any {error_name} tests")
                    logger.debug(
                        "Last test created: {}".format(
                            testList[-1] if testList else None
                        )
                    )
            else:
                # Successfully created at least one ERROR_IF test
                num_error_types_created += 1

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            clean_testList = []
            for test in testList:
                remove_test = False
                for validator_fcn in invalid_test_validators:
                    if validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    ):
                        remove_test = True
                if not remove_test:
                    clean_testList.append(test)
            testList = clean_testList
    else:
        if num_error_types_created is not None and not self.args.level8k:
            remaining_error_types = (
                len(error_if_validators) - num_error_types_created
            )
            if remaining_error_types:
                raise Exception(
                    f"Failed to create {remaining_error_types} error types for {opName}"
                )

    return testList
3071
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build one test case and serialize it if the build succeeds.

    Args:
        opName: Key into TOSA_OP_LIST.
        testStr: Name of the test being created.
        dtype_or_dtypeList: A single dtype (replicated per operand) or an
            explicit list of per-operand dtypes.
        error_name: ERROR_IF condition under test, or None.
        shapeList: Input shapes, one per operand.
        argsDict: Argument dictionary for the tensor-values generator and the
            build function; must contain "dg_type".

    Returns:
        True if the test was serialized, False if the build function returned
        a falsy result (reported as an invalid ERROR_IF test).

    Raises:
        Exception: If ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    # Concatenation ops take a variable number of inputs (one per shape)
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeat = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeat

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Random number generator used for every stage of building this test
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    qinfo = (
        None
        if qgen is None
        else qgen(build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name)
    )

    # Extra meta data for the desc.json
    tensMeta = {}

    # Only the argsDict-based tvg/build_fcn interface is supported
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    tags = argsDict.get("tags", None)

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return False

    # The test is valid: attach compliance meta data (if any) and serialize
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta, tags)
    return True
Matthew Haddon1c00b712021-10-01 15:51:03 +01003178
def createDynamicOpLists(self):
    """Expand every template op in TOSA_OP_LIST into one op per kernel size.

    A template is an entry whose "template" field is truthy; its key must end
    in "_TEMPLATE". Each template is replaced by ops named
    ``<realName>_<k0>x<k1>[x<k2>]``. conv3d uses rank-3 kernels and every
    other template uses rank-2 kernels. Kernel sizes are chosen as follows:

    1. the level8k big-kernel set when ``args.level8k`` is set;
    2. otherwise the rank-matching entries of ``args.conv_kernels``;
    3. otherwise the template's own "filter" list.
    """
    templates = [
        name for name, entry in self.TOSA_OP_LIST.items() if entry.get("template")
    ]

    big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
    suffix = "_TEMPLATE"

    for template_name in templates:
        assert template_name.endswith(suffix), "Found incorrect template"
        real_name = template_name[: len(template_name) - len(suffix)]
        template = self.TOSA_OP_LIST[template_name]
        k_rank = 3 if real_name == "conv3d" else 2

        # Choose kernels to build tests for from the template or args
        if self.args.level8k:
            if k_rank == 3:
                kernels = [[1, big_k, 1], [2, 2, big_k]]
            else:
                kernels = [[1, big_k], [big_k, 2]]
        else:
            kernels = []
            requested = self.args.conv_kernels
            if len(requested) > 0:
                kernels = [k for k in requested if len(k) == k_rank]
                if not kernels:
                    logger.debug(
                        f"{real_name} op using defaults as no rank {k_rank} kernels found in {self.args.conv_kernels}"
                    )
            if not kernels:
                # Fall back to the kernels defined on the template itself
                kernels = template["filter"]

        # One concrete (non-template) op per kernel size
        for kernel in kernels:
            kernel_str = "x".join(map(str, kernel))
            self.TOSA_OP_LIST[f"{real_name}_{kernel_str}"] = {
                **template,
                "filter": kernel,
                "template": False,
                "real_name": real_name,
            }

        # The template itself is removed once it has been expanded
        del self.TOSA_OP_LIST[template_name]
Eric Kunzee5e26762020-10-13 16:11:07 -07003228
def initOpListDefaults(self):
    """Validate every TOSA_OP_LIST entry and fill in optional defaults.

    Each entry must provide:

    - a two-element "operands" value;
    - a four-element "build_fcn" value;
    - a "types" field;
    - an "op" field.

    A generic Exception naming the first offending op is raised otherwise.
    Entries without a "rank" field get DEFAULT_RANK_RANGE.
    """
    for name, entry in self.TOSA_OP_LIST.items():
        # Required fields whose structure is checked by unpacking
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid operand tuple in TOSA_OP_LIST"
            )

        try:
            _build, _tgen, _tvgen, _agen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                f"Op {name} is missing a valid build_fcn tuple in TOSA_OP_LIST"
            )

        # Required fields that only need to be present
        for field, description in (
            ("types", "a valid type list"),
            ("op", "the Op field"),
        ):
            if field not in entry:
                raise Exception(f"Op {name} is missing {description} in TOSA_OP_LIST")

        # Optional field: default rank range
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003270
# Tensor operator list (TOSA_OP_LIST) entry fields:
# 'op': the TOSA Op enum value for the operator
# 'operands': tuple of (placeholder, const) operand counts
# 'rank': optional, restricts rank to tuple inclusive of (min, max),
#         if not specified, defaults to DEFAULT_RANK_RANGE
# 'build_fcn': tuple of (build function, TensorGen function,
#              TensorValuesGen function, ArgGen function)
# 'types': array of datatypes to be tested
# 'qgen': optional, quantization info generator
# 'error_if_validators': optional, ERROR_IF validators used for negative tests
# 'invalid_test_validators': optional, validators that remove positive tests
#                            expected to fail without a matching ERROR_IF
# 'template': optional, marks a "*_TEMPLATE" entry expanded per kernel size
#             (using its 'filter' kernel list) by createDynamicOpLists
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
TYPE_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]  # Excludes INT4

TYPE_BOOL = [DType.BOOL]
TYPE_FI32 = [
    DType.FP32,
    DType.FP16,
    DType.BF16,
    DType.INT32,
]  # floating-types and INT32
TYPE_FIB = [
    DType.FP16,
    DType.BF16,
    DType.FP32,
    DType.INT8,
    DType.INT16,
    DType.INT32,
    DType.BOOL,
]  # floating-types, integer types (excluding INT4) and BOOL
TYPE_FI16 = [DType.FP32, DType.INT16]

# Integer types narrower than INT32, plus floating-types
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]
Eric Kunzee5e26762020-10-13 16:11:07 -07003315
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003316 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07003317 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003318 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003319 [DType.INT8, DType.INT8, DType.INT32],
3320 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01003321 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003322 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01003323 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003324 [DType.FP32, DType.FP32, DType.FP32],
Won Jeon2c34b462024-02-06 18:37:00 +00003325 [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
3326 [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
Kevin Cheng989cb052021-04-28 16:29:44 -07003327 ]
3328
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003329 DEFAULT_RANK_RANGE = (1, gtu.MAX_TENSOR_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003330
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003331 KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
3332 KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
3333
evacha01ad8e1e22024-03-19 12:42:17 +00003334 PSEUDO_RANDOM_DATAGEN = {
3335 DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM,),
3336 DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM,),
3337 }
3338 DOT_PRODUCT_DATAGEN = {
3339 DType.FP16: (gtu.DataGenType.DOT_PRODUCT,),
3340 DType.FP32: (gtu.DataGenType.DOT_PRODUCT,),
3341 }
3342 EW_UNARY_DATAGEN = {
3343 DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE)
3344 }
3345
Eric Kunzee5e26762020-10-13 16:11:07 -07003346 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003347 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003348 "argmax": {
3349 "op": Op.ARGMAX,
3350 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003351 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003352 "build_fcn": (
3353 build_argmax,
3354 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003355 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003356 TosaArgGen.agAxis,
3357 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003358 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003359 "error_if_validators": (
3360 TosaErrorValidator.evAxisSmallerZero,
3361 TosaErrorValidator.evAxisLargerRank,
3362 TosaErrorValidator.evArgmaxOutputRankMismatch,
3363 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3364 TosaErrorValidator.evWrongRank,
3365 TosaErrorValidator.evWrongInputType,
3366 TosaErrorValidator.evWrongOutputType,
3367 TosaErrorValidator.evWrongInputList,
3368 TosaErrorValidator.evWrongOutputList,
3369 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003370 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003371 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003372 "avg_pool2d": {
3373 "op": Op.AVG_POOL2D,
3374 "operands": (1, 0),
3375 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003376 "build_fcn": (
3377 build_pool2d,
3378 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003379 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003380 TosaArgGen.agPooling,
3381 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003382 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003383 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003384 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003385 "error_if_validators": (
3386 TosaErrorValidator.evKernelSmallerOne,
3387 TosaErrorValidator.evStrideSmallerOne,
3388 TosaErrorValidator.evPadSmallerZero,
3389 TosaErrorValidator.evWrongRank,
3390 TosaErrorValidator.evWrongInputType,
3391 TosaErrorValidator.evWrongOutputType,
3392 TosaErrorValidator.evWrongInputList,
3393 TosaErrorValidator.evWrongOutputList,
3394 TosaErrorValidator.evInputZeroPointNotZero,
3395 TosaErrorValidator.evOutputZeroPointNotZero,
3396 TosaErrorValidator.evPadLargerEqualKernel,
3397 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003398 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003399 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003400 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003401 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003402 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003403 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003404 "conv2d_TEMPLATE": {
3405 "op": Op.CONV2D,
3406 "operands": (1, 2),
3407 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003408 "build_fcn": (
3409 build_conv2d,
3410 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003411 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003412 TosaArgGen.agConv,
3413 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003414 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003415 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003416 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3417 "error_if_validators": (
3418 TosaErrorValidator.evWrongInputType,
3419 TosaErrorValidator.evWrongOutputType,
3420 TosaErrorValidator.evWrongInputList,
3421 TosaErrorValidator.evWrongOutputList,
3422 TosaErrorValidator.evInputZeroPointNotZero,
3423 TosaErrorValidator.evWeightZeroPointNotZero,
3424 TosaErrorValidator.evPadSmallerZero,
3425 TosaErrorValidator.evStrideSmallerOne,
3426 TosaErrorValidator.evDilationSmallerOne,
3427 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003428 TosaErrorValidator.evConvOutputShapeMismatch,
3429 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003430 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003431 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003432 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003433 "broadcastable_bias": True,
3434 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003435 "template": True,
3436 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003437 # Templated operator. Filled in by createDynamicOpLists
3438 "conv3d_TEMPLATE": {
3439 "op": Op.CONV3D,
3440 "operands": (1, 2),
3441 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003442 "build_fcn": (
3443 build_conv3d,
3444 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003445 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003446 TosaArgGen.agConv,
3447 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003448 "qgen": TosaQuantGen.qgConv,
3449 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003450 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3451 "error_if_validators": (
3452 TosaErrorValidator.evWrongInputType,
3453 TosaErrorValidator.evWrongOutputType,
3454 TosaErrorValidator.evWrongInputList,
3455 TosaErrorValidator.evWrongOutputList,
3456 TosaErrorValidator.evInputZeroPointNotZero,
3457 TosaErrorValidator.evWeightZeroPointNotZero,
3458 TosaErrorValidator.evPadSmallerZero,
3459 TosaErrorValidator.evStrideSmallerOne,
3460 TosaErrorValidator.evDilationSmallerOne,
3461 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003462 TosaErrorValidator.evConvOutputShapeMismatch,
3463 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003464 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003465 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003466 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003467 "filter": KERNELS_3D,
Kevin Cheng1533b852021-09-01 12:51:58 -07003468 "template": True,
3469 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003470 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003471 "depthwise_conv2d_TEMPLATE": {
3472 "op": Op.DEPTHWISE_CONV2D,
3473 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003474 "rank": (4, 4),
3475 "build_fcn": (
3476 build_depthwise_conv2d,
3477 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003478 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003479 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003480 ),
3481 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003482 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003483 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3484 "error_if_validators": (
3485 TosaErrorValidator.evWrongInputType,
3486 TosaErrorValidator.evWrongOutputType,
3487 TosaErrorValidator.evWrongInputList,
3488 TosaErrorValidator.evWrongOutputList,
3489 TosaErrorValidator.evInputZeroPointNotZero,
3490 TosaErrorValidator.evWeightZeroPointNotZero,
3491 TosaErrorValidator.evPadSmallerZero,
3492 TosaErrorValidator.evStrideSmallerOne,
3493 TosaErrorValidator.evDilationSmallerOne,
3494 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003495 TosaErrorValidator.evConvOutputShapeMismatch,
3496 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003497 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003498 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003499 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003500 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003501 "template": True,
3502 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003503 "fully_connected": {
3504 "op": Op.FULLY_CONNECTED,
3505 "operands": (1, 2),
3506 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003507 "build_fcn": (
3508 build_fully_connected,
3509 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003510 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003511 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003512 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003513 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003514 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003515 "error_if_validators": (
3516 TosaErrorValidator.evInputZeroPointNotZero,
3517 TosaErrorValidator.evWeightZeroPointNotZero,
3518 TosaErrorValidator.evWrongRank,
3519 TosaErrorValidator.evWrongInputType,
3520 TosaErrorValidator.evWrongOutputType,
3521 TosaErrorValidator.evWrongInputList,
3522 TosaErrorValidator.evWrongOutputList,
3523 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003524 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003526 "matmul": {
3527 "op": Op.MATMUL,
3528 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003529 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 "build_fcn": (
3531 build_matmul,
3532 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003533 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003534 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003535 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003536 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003537 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003538 "error_if_validators": (
3539 TosaErrorValidator.evInputZeroPointNotZero,
3540 TosaErrorValidator.evWrongRank,
3541 TosaErrorValidator.evWrongInputType,
3542 TosaErrorValidator.evWrongOutputType,
3543 TosaErrorValidator.evWrongInputList,
3544 TosaErrorValidator.evWrongOutputList,
3545 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003546 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 "max_pool2d": {
3549 "op": Op.MAX_POOL2D,
3550 "operands": (1, 0),
3551 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003552 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003553 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003554 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003555 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003556 TosaArgGen.agPooling,
3557 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003558 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003559 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003560 "error_if_validators": (
3561 TosaErrorValidator.evKernelSmallerOne,
3562 TosaErrorValidator.evStrideSmallerOne,
3563 TosaErrorValidator.evPadSmallerZero,
3564 TosaErrorValidator.evWrongRank,
3565 TosaErrorValidator.evWrongInputType,
3566 TosaErrorValidator.evWrongOutputType,
3567 TosaErrorValidator.evWrongInputList,
3568 TosaErrorValidator.evWrongOutputList,
3569 TosaErrorValidator.evPadLargerEqualKernel,
3570 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003571 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003572 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003573 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003574 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003575 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003576 "transpose_conv2d_TEMPLATE": {
3577 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003578 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003579 "rank": (4, 4),
3580 "build_fcn": (
3581 build_transpose_conv2d,
3582 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003583 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003584 TosaArgGen.agTransposeConv2D,
3585 ),
3586 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003587 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003588 "invalid_test_validators": (
3589 TosaInvalidValidator.ivHeightWidthInvalid,
3590 TosaInvalidValidator.ivNonPositiveOutputShape,
3591 ),
3592 "error_if_validators": (
3593 TosaErrorValidator.evWrongInputType,
3594 TosaErrorValidator.evWrongOutputType,
3595 TosaErrorValidator.evWrongInputList,
3596 TosaErrorValidator.evWrongOutputList,
3597 TosaErrorValidator.evInputZeroPointNotZero,
3598 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003599 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003600 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003601 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003602 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003603 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003604 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003605 "data_gen": DOT_PRODUCT_DATAGEN,
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003606 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003607 "template": True,
3608 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003609 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003610 "clamp": {
3611 "op": Op.CLAMP,
3612 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003613 "build_fcn": (
3614 build_clamp,
3615 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003616 TosaTensorValuesGen.tvgLazyGenDefault,
3617 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003618 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003619 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003620 "error_if_validators": (
3621 TosaErrorValidator.evMaxSmallerMin,
3622 TosaErrorValidator.evWrongInputType,
3623 TosaErrorValidator.evWrongOutputType,
3624 TosaErrorValidator.evWrongInputList,
3625 TosaErrorValidator.evWrongOutputList,
3626 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003627 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003628 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003629 "sigmoid": {
3630 "op": Op.SIGMOID,
3631 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003632 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003633 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003634 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003635 TosaTensorValuesGen.tvgLazyGenDefault,
3636 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003638 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003639 "error_if_validators": (
3640 TosaErrorValidator.evWrongInputType,
3641 TosaErrorValidator.evWrongOutputType,
3642 TosaErrorValidator.evWrongInputList,
3643 TosaErrorValidator.evWrongOutputList,
3644 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003645 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003646 },
3647 "tanh": {
3648 "op": Op.TANH,
3649 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003650 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003651 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003652 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003653 TosaTensorValuesGen.tvgLazyGenDefault,
3654 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003656 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003657 "error_if_validators": (
3658 TosaErrorValidator.evWrongInputType,
3659 TosaErrorValidator.evWrongOutputType,
3660 TosaErrorValidator.evWrongInputList,
3661 TosaErrorValidator.evWrongOutputList,
3662 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003663 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003664 "compliance": {
3665 "abs_error_lower_bound": 0.5,
3666 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003667 },
Won Jeon78155c62023-06-10 00:20:04 +00003668 "erf": {
3669 "op": Op.ERF,
3670 "operands": (1, 0),
3671 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003672 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003673 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003674 TosaTensorValuesGen.tvgLazyGenDefault,
3675 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003676 ),
3677 "types": TYPE_FP,
3678 "error_if_validators": (
3679 TosaErrorValidator.evWrongInputType,
3680 TosaErrorValidator.evWrongOutputType,
3681 TosaErrorValidator.evWrongInputList,
3682 TosaErrorValidator.evWrongOutputList,
3683 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003684 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003685 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003687 # Elementwise Binary Operators
3688 "add": {
3689 "op": Op.ADD,
3690 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003691 "build_fcn": (
3692 build_binary_broadcast,
3693 TosaTensorGen.tgBroadcastFuzz,
3694 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003695 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003696 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003697 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003698 "error_if_validators": (
3699 TosaErrorValidator.evRankMismatch,
3700 TosaErrorValidator.evWrongInputType,
3701 TosaErrorValidator.evWrongOutputType,
3702 TosaErrorValidator.evWrongInputList,
3703 TosaErrorValidator.evWrongOutputList,
3704 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003705 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003706 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003707 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003708 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003709 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003710 "arithmetic_right_shift": {
3711 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3712 "operands": (2, 0),
3713 "build_fcn": (
3714 build_arithmetic_right_shift,
3715 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003716 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003717 TosaArgGen.agArithmeticRightShift,
3718 ),
3719 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003720 "error_if_validators": (
3721 TosaErrorValidator.evRankMismatch,
3722 TosaErrorValidator.evWrongInputType,
3723 TosaErrorValidator.evWrongOutputType,
3724 TosaErrorValidator.evWrongInputList,
3725 TosaErrorValidator.evWrongOutputList,
3726 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003727 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003728 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003729 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003730 "bitwise_and": {
3731 "op": Op.BITWISE_AND,
3732 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003733 "build_fcn": (
3734 build_binary_broadcast,
3735 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003736 TosaTensorValuesGen.tvgLazyGenDefault,
3737 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003738 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003739 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003740 "error_if_validators": (
3741 TosaErrorValidator.evRankMismatch,
3742 TosaErrorValidator.evWrongInputType,
3743 TosaErrorValidator.evWrongOutputType,
3744 TosaErrorValidator.evWrongInputList,
3745 TosaErrorValidator.evWrongOutputList,
3746 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003747 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003748 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003749 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003750 "bitwise_or": {
3751 "op": Op.BITWISE_OR,
3752 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003753 "build_fcn": (
3754 build_binary_broadcast,
3755 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003756 TosaTensorValuesGen.tvgLazyGenDefault,
3757 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003758 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003759 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003760 "error_if_validators": (
3761 TosaErrorValidator.evRankMismatch,
3762 TosaErrorValidator.evWrongInputType,
3763 TosaErrorValidator.evWrongOutputType,
3764 TosaErrorValidator.evWrongInputList,
3765 TosaErrorValidator.evWrongOutputList,
3766 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003767 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003768 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003769 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 "bitwise_xor": {
3771 "op": Op.BITWISE_XOR,
3772 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003773 "build_fcn": (
3774 build_binary_broadcast,
3775 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003776 TosaTensorValuesGen.tvgLazyGenDefault,
3777 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003778 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003779 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003780 "error_if_validators": (
3781 TosaErrorValidator.evRankMismatch,
3782 TosaErrorValidator.evWrongInputType,
3783 TosaErrorValidator.evWrongOutputType,
3784 TosaErrorValidator.evWrongInputList,
3785 TosaErrorValidator.evWrongOutputList,
3786 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003787 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003788 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003789 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003790 "intdiv": {
3791 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003792 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003793 "build_fcn": (
3794 build_binary_broadcast,
3795 TosaTensorGen.tgBroadcastFuzz,
3796 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003797 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003798 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003799 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003800 "error_if_validators": (
3801 TosaErrorValidator.evRankMismatch,
3802 TosaErrorValidator.evWrongInputType,
3803 TosaErrorValidator.evWrongOutputType,
3804 TosaErrorValidator.evWrongInputList,
3805 TosaErrorValidator.evWrongOutputList,
3806 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003807 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003808 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003809 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003810 "logical_and": {
3811 "op": Op.LOGICAL_AND,
3812 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003813 "build_fcn": (
3814 build_binary_broadcast,
3815 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003816 TosaTensorValuesGen.tvgLazyGenDefault,
3817 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003818 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003819 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003820 "error_if_validators": (
3821 TosaErrorValidator.evRankMismatch,
3822 TosaErrorValidator.evWrongInputType,
3823 TosaErrorValidator.evWrongOutputType,
3824 TosaErrorValidator.evWrongInputList,
3825 TosaErrorValidator.evWrongOutputList,
3826 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003827 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003828 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003829 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "logical_left_shift": {
3831 "op": Op.LOGICAL_LEFT_SHIFT,
3832 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003833 "build_fcn": (
3834 build_binary_broadcast,
3835 TosaTensorGen.tgBroadcastFuzz,
3836 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003837 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003838 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003839 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003840 "error_if_validators": (
3841 TosaErrorValidator.evRankMismatch,
3842 TosaErrorValidator.evWrongInputType,
3843 TosaErrorValidator.evWrongOutputType,
3844 TosaErrorValidator.evWrongInputList,
3845 TosaErrorValidator.evWrongOutputList,
3846 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003847 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003848 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003849 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "logical_right_shift": {
3851 "op": Op.LOGICAL_RIGHT_SHIFT,
3852 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003853 "build_fcn": (
3854 build_binary_broadcast,
3855 TosaTensorGen.tgBroadcastFuzz,
3856 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003857 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003858 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003859 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003860 "error_if_validators": (
3861 TosaErrorValidator.evRankMismatch,
3862 TosaErrorValidator.evWrongInputType,
3863 TosaErrorValidator.evWrongOutputType,
3864 TosaErrorValidator.evWrongInputList,
3865 TosaErrorValidator.evWrongOutputList,
3866 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003867 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003868 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003869 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 "logical_or": {
3871 "op": Op.LOGICAL_OR,
3872 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003873 "build_fcn": (
3874 build_binary_broadcast,
3875 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003876 TosaTensorValuesGen.tvgLazyGenDefault,
3877 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003878 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003879 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003880 "error_if_validators": (
3881 TosaErrorValidator.evRankMismatch,
3882 TosaErrorValidator.evWrongInputType,
3883 TosaErrorValidator.evWrongOutputType,
3884 TosaErrorValidator.evWrongInputList,
3885 TosaErrorValidator.evWrongOutputList,
3886 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003887 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003888 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003889 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 "logical_xor": {
3891 "op": Op.LOGICAL_XOR,
3892 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003893 "build_fcn": (
3894 build_binary_broadcast,
3895 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003896 TosaTensorValuesGen.tvgLazyGenDefault,
3897 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003898 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003899 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003900 "error_if_validators": (
3901 TosaErrorValidator.evRankMismatch,
3902 TosaErrorValidator.evWrongInputType,
3903 TosaErrorValidator.evWrongOutputType,
3904 TosaErrorValidator.evWrongInputList,
3905 TosaErrorValidator.evWrongOutputList,
3906 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003907 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003908 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003909 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 "maximum": {
3911 "op": Op.MAXIMUM,
3912 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003913 "build_fcn": (
3914 build_binary_broadcast,
3915 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003916 TosaTensorValuesGen.tvgLazyGenDefault,
3917 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003918 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003919 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003920 "error_if_validators": (
3921 TosaErrorValidator.evRankMismatch,
3922 TosaErrorValidator.evWrongInputType,
3923 TosaErrorValidator.evWrongOutputType,
3924 TosaErrorValidator.evWrongInputList,
3925 TosaErrorValidator.evWrongOutputList,
3926 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003927 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003928 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003929 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003931 "minimum": {
3932 "op": Op.MINIMUM,
3933 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003934 "build_fcn": (
3935 build_binary_broadcast,
3936 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003937 TosaTensorValuesGen.tvgLazyGenDefault,
3938 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003939 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003940 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003941 "error_if_validators": (
3942 TosaErrorValidator.evRankMismatch,
3943 TosaErrorValidator.evWrongInputType,
3944 TosaErrorValidator.evWrongOutputType,
3945 TosaErrorValidator.evWrongInputList,
3946 TosaErrorValidator.evWrongOutputList,
3947 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003948 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003949 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003950 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003951 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003952 "mul": {
3953 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003954 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 "build_fcn": (
3956 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003957 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003958 TosaTensorValuesGen.tvgMul,
3959 TosaArgGen.agMul,
3960 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003961 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 "error_if_validators": (
3963 TosaErrorValidator.evWrongInputType,
3964 TosaErrorValidator.evWrongOutputType,
3965 TosaErrorValidator.evWrongInputList,
3966 TosaErrorValidator.evWrongOutputList,
3967 TosaErrorValidator.evRankMismatch,
3968 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003969 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003970 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003971 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003972 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003973 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003974 "pow": {
3975 "op": Op.POW,
3976 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003977 "build_fcn": (
3978 build_binary_broadcast,
3979 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003980 TosaTensorValuesGen.tvgPow,
3981 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003982 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003983 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003984 "error_if_validators": (
3985 TosaErrorValidator.evRankMismatch,
3986 TosaErrorValidator.evWrongInputType,
3987 TosaErrorValidator.evWrongOutputType,
3988 TosaErrorValidator.evWrongInputList,
3989 TosaErrorValidator.evWrongOutputList,
3990 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003991 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003992 ),
evacha01ad8e1e22024-03-19 12:42:17 +00003993 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08003994 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003995 "sub": {
3996 "op": Op.SUB,
3997 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003998 "build_fcn": (
3999 build_binary_broadcast,
4000 TosaTensorGen.tgBroadcastFuzz,
4001 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004002 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004003 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004004 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004005 "error_if_validators": (
4006 TosaErrorValidator.evRankMismatch,
4007 TosaErrorValidator.evWrongInputType,
4008 TosaErrorValidator.evWrongOutputType,
4009 TosaErrorValidator.evWrongInputList,
4010 TosaErrorValidator.evWrongOutputList,
4011 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004012 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004013 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004014 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004015 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 "table": {
4018 "op": Op.TABLE,
4019 # Use the automatic generation functions to create the input array
4020 # but create the table tensor in the build function, as it may be
4021 # a different type from the input
4022 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004023 "build_fcn": (
4024 build_table,
4025 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004026 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004027 TosaArgGen.agTable,
4028 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004029 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004030 "error_if_validators": (
4031 TosaErrorValidator.evWrongInputType,
4032 TosaErrorValidator.evWrongOutputType,
4033 TosaErrorValidator.evWrongInputList,
4034 TosaErrorValidator.evWrongOutputList,
4035 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004036 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004037 # Elementwise Unary operators
4038 "abs": {
4039 "op": Op.ABS,
4040 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004041 "build_fcn": (
4042 build_unary,
4043 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004044 TosaTensorValuesGen.tvgLazyGenDefault,
4045 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004046 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004047 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004048 "error_if_validators": (
4049 TosaErrorValidator.evWrongInputType,
4050 TosaErrorValidator.evWrongOutputType,
4051 TosaErrorValidator.evWrongInputList,
4052 TosaErrorValidator.evWrongOutputList,
4053 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004054 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004055 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004056 "bitwise_not": {
4057 "op": Op.BITWISE_NOT,
4058 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004059 "build_fcn": (
4060 build_unary,
4061 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004062 TosaTensorValuesGen.tvgLazyGenDefault,
4063 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004064 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004065 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004066 "error_if_validators": (
4067 TosaErrorValidator.evWrongInputType,
4068 TosaErrorValidator.evWrongOutputType,
4069 TosaErrorValidator.evWrongInputList,
4070 TosaErrorValidator.evWrongOutputList,
4071 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004072 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004073 "ceil": {
4074 "op": Op.CEIL,
4075 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004076 "build_fcn": (
4077 build_unary,
4078 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004079 TosaTensorValuesGen.tvgLazyGenDefault,
4080 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004081 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004083 "error_if_validators": (
4084 TosaErrorValidator.evWrongInputType,
4085 TosaErrorValidator.evWrongOutputType,
4086 TosaErrorValidator.evWrongInputList,
4087 TosaErrorValidator.evWrongOutputList,
4088 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004089 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004090 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004091 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004092 "clz": {
4093 "op": Op.CLZ,
4094 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004095 "build_fcn": (
4096 build_unary,
4097 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004098 TosaTensorValuesGen.tvgLazyGenDefault,
4099 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004100 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004101 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004102 "error_if_validators": (
4103 TosaErrorValidator.evWrongInputType,
4104 TosaErrorValidator.evWrongOutputType,
4105 TosaErrorValidator.evWrongInputList,
4106 TosaErrorValidator.evWrongOutputList,
4107 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004108 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004109 "cos": {
4110 "op": Op.COS,
4111 "operands": (1, 0),
4112 "build_fcn": (
4113 build_unary,
4114 TosaTensorGen.tgBasic,
4115 TosaTensorValuesGen.tvgLazyGenDefault,
4116 TosaArgGen.agNone,
4117 ),
4118 "types": TYPE_FP,
4119 "error_if_validators": (
4120 TosaErrorValidator.evWrongInputType,
4121 TosaErrorValidator.evWrongOutputType,
4122 TosaErrorValidator.evWrongInputList,
4123 TosaErrorValidator.evWrongOutputList,
4124 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004125 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004126 "compliance": {"abs_error_normal_divisor": 2},
4127 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004128 "exp": {
4129 "op": Op.EXP,
4130 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004131 "build_fcn": (
4132 build_unary,
4133 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004134 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004135 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004136 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004137 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004138 "error_if_validators": (
4139 TosaErrorValidator.evWrongInputType,
4140 TosaErrorValidator.evWrongOutputType,
4141 TosaErrorValidator.evWrongInputList,
4142 TosaErrorValidator.evWrongOutputList,
4143 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004144 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004145 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004146 "floor": {
4147 "op": Op.FLOOR,
4148 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004149 "build_fcn": (
4150 build_unary,
4151 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004152 TosaTensorValuesGen.tvgLazyGenDefault,
4153 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004154 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004155 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004156 "error_if_validators": (
4157 TosaErrorValidator.evWrongInputType,
4158 TosaErrorValidator.evWrongOutputType,
4159 TosaErrorValidator.evWrongInputList,
4160 TosaErrorValidator.evWrongOutputList,
4161 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004162 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004163 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004165 "log": {
4166 "op": Op.LOG,
4167 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004168 "build_fcn": (
4169 build_unary,
4170 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004171 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004172 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004173 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004174 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004175 "error_if_validators": (
4176 TosaErrorValidator.evWrongInputType,
4177 TosaErrorValidator.evWrongOutputType,
4178 TosaErrorValidator.evWrongInputList,
4179 TosaErrorValidator.evWrongOutputList,
4180 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004181 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004182 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004183 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004184 "logical_not": {
4185 "op": Op.LOGICAL_NOT,
4186 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004187 "build_fcn": (
4188 build_unary,
4189 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004190 TosaTensorValuesGen.tvgLazyGenDefault,
4191 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004192 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004193 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004194 "error_if_validators": (
4195 TosaErrorValidator.evWrongInputType,
4196 TosaErrorValidator.evWrongOutputType,
4197 TosaErrorValidator.evWrongInputList,
4198 TosaErrorValidator.evWrongOutputList,
4199 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004200 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004201 "negate": {
4202 "op": Op.NEGATE,
4203 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004204 "build_fcn": (
4205 build_unary,
4206 TosaTensorGen.tgBasic,
4207 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004208 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004209 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004210 "qgen": TosaQuantGen.qgUnary,
4211 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004212 "error_if_validators": (
4213 TosaErrorValidator.evInputZeroPointNotZero,
4214 TosaErrorValidator.evOutputZeroPointNotZero,
4215 TosaErrorValidator.evWrongInputType,
4216 TosaErrorValidator.evWrongOutputType,
4217 TosaErrorValidator.evWrongInputList,
4218 TosaErrorValidator.evWrongOutputList,
4219 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004220 "data_gen": EW_UNARY_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004221 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004222 "reciprocal": {
4223 "op": Op.RECIPROCAL,
4224 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004225 "build_fcn": (
4226 build_unary,
4227 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004228 TosaTensorValuesGen.tvgLazyGenDefault,
4229 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004230 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004231 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004232 "error_if_validators": (
4233 TosaErrorValidator.evWrongInputType,
4234 TosaErrorValidator.evWrongOutputType,
4235 TosaErrorValidator.evWrongInputList,
4236 TosaErrorValidator.evWrongOutputList,
4237 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004238 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004239 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004241 "rsqrt": {
4242 "op": Op.RSQRT,
4243 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004244 "build_fcn": (
4245 build_unary,
4246 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004247 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004248 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004249 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004250 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004251 "error_if_validators": (
4252 TosaErrorValidator.evWrongInputType,
4253 TosaErrorValidator.evWrongOutputType,
4254 TosaErrorValidator.evWrongInputList,
4255 TosaErrorValidator.evWrongOutputList,
4256 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004257 "data_gen": EW_UNARY_DATAGEN,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004258 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004259 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004260 "sin": {
4261 "op": Op.SIN,
4262 "operands": (1, 0),
4263 "build_fcn": (
4264 build_unary,
4265 TosaTensorGen.tgBasic,
4266 TosaTensorValuesGen.tvgLazyGenDefault,
4267 TosaArgGen.agNone,
4268 ),
4269 "types": TYPE_FP,
4270 "error_if_validators": (
4271 TosaErrorValidator.evWrongInputType,
4272 TosaErrorValidator.evWrongOutputType,
4273 TosaErrorValidator.evWrongInputList,
4274 TosaErrorValidator.evWrongOutputList,
4275 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004276 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jerry Ge51bd4f52024-02-20 11:21:19 -08004277 "compliance": {"abs_error_normal_divisor": 2},
4278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004279 # Elementwise Ternary operators
4280 "select": {
4281 "op": Op.SELECT,
4282 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004283 "build_fcn": (
4284 build_select,
4285 TosaTensorGen.tgBroadcastFuzz,
4286 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004287 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004288 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004289 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004290 "error_if_validators": (
4291 TosaErrorValidator.evRankMismatch,
4292 TosaErrorValidator.evWrongInputType,
4293 TosaErrorValidator.evWrongOutputType,
4294 TosaErrorValidator.evWrongInputList,
4295 TosaErrorValidator.evWrongOutputList,
4296 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004297 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004298 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004299 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004300 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004301 # Comparison operators
4302 "equal": {
4303 "op": Op.EQUAL,
4304 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004305 "build_fcn": (
4306 build_comparison,
4307 TosaTensorGen.tgBroadcastFuzz,
4308 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004309 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004310 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004311 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004312 "error_if_validators": (
4313 TosaErrorValidator.evRankMismatch,
4314 TosaErrorValidator.evWrongInputType,
4315 TosaErrorValidator.evWrongOutputType,
4316 TosaErrorValidator.evWrongInputList,
4317 TosaErrorValidator.evWrongOutputList,
4318 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004319 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004320 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004321 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004322 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004323 "greater_equal": {
4324 "op": Op.GREATER_EQUAL,
4325 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004326 "build_fcn": (
4327 build_comparison,
4328 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004329 TosaTensorValuesGen.tvgLazyGenDefault,
4330 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004331 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004332 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004333 "error_if_validators": (
4334 TosaErrorValidator.evRankMismatch,
4335 TosaErrorValidator.evWrongInputType,
4336 TosaErrorValidator.evWrongOutputType,
4337 TosaErrorValidator.evWrongInputList,
4338 TosaErrorValidator.evWrongOutputList,
4339 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004340 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004341 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004342 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004344 "greater": {
4345 "op": Op.GREATER,
4346 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004347 "build_fcn": (
4348 build_comparison,
4349 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004350 TosaTensorValuesGen.tvgLazyGenDefault,
4351 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004352 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004353 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004354 "error_if_validators": (
4355 TosaErrorValidator.evRankMismatch,
4356 TosaErrorValidator.evWrongInputType,
4357 TosaErrorValidator.evWrongOutputType,
4358 TosaErrorValidator.evWrongInputList,
4359 TosaErrorValidator.evWrongOutputList,
4360 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004361 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004362 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004363 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004364 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004365 # Reduction operators
4366 "reduce_all": {
4367 "op": Op.REDUCE_ALL,
4368 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004369 "build_fcn": (
4370 build_reduce,
4371 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004372 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004373 TosaArgGen.agAxis,
4374 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004375 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004376 "error_if_validators": (
4377 TosaErrorValidator.evAxisLargerRank,
4378 TosaErrorValidator.evAxisSmallerZero,
4379 TosaErrorValidator.evShapeOfAxisNotOne,
4380 TosaErrorValidator.evWrongInputType,
4381 TosaErrorValidator.evWrongOutputType,
4382 TosaErrorValidator.evWrongRank,
4383 TosaErrorValidator.evWrongInputList,
4384 TosaErrorValidator.evWrongOutputList,
4385 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004386 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004387 "reduce_any": {
4388 "op": Op.REDUCE_ANY,
4389 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004390 "build_fcn": (
4391 build_reduce,
4392 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004393 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004394 TosaArgGen.agAxis,
4395 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004396 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004397 "error_if_validators": (
4398 TosaErrorValidator.evAxisLargerRank,
4399 TosaErrorValidator.evAxisSmallerZero,
4400 TosaErrorValidator.evShapeOfAxisNotOne,
4401 TosaErrorValidator.evWrongInputType,
4402 TosaErrorValidator.evWrongOutputType,
4403 TosaErrorValidator.evWrongRank,
4404 TosaErrorValidator.evWrongInputList,
4405 TosaErrorValidator.evWrongOutputList,
4406 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004407 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004408 "reduce_max": {
4409 "op": Op.REDUCE_MAX,
4410 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004411 "build_fcn": (
4412 build_reduce,
4413 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004414 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004415 TosaArgGen.agAxis,
4416 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004417 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004418 "error_if_validators": (
4419 TosaErrorValidator.evAxisLargerRank,
4420 TosaErrorValidator.evAxisSmallerZero,
4421 TosaErrorValidator.evShapeOfAxisNotOne,
4422 TosaErrorValidator.evWrongInputType,
4423 TosaErrorValidator.evWrongOutputType,
4424 TosaErrorValidator.evWrongRank,
4425 TosaErrorValidator.evWrongInputList,
4426 TosaErrorValidator.evWrongOutputList,
4427 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004428 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004429 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004430 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004431 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004432 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004433 "build_fcn": (
4434 build_reduce,
4435 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004436 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004437 TosaArgGen.agAxis,
4438 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004439 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004440 "error_if_validators": (
4441 TosaErrorValidator.evAxisLargerRank,
4442 TosaErrorValidator.evAxisSmallerZero,
4443 TosaErrorValidator.evShapeOfAxisNotOne,
4444 TosaErrorValidator.evWrongInputType,
4445 TosaErrorValidator.evWrongOutputType,
4446 TosaErrorValidator.evWrongRank,
4447 TosaErrorValidator.evWrongInputList,
4448 TosaErrorValidator.evWrongOutputList,
4449 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004450 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004451 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004452 "reduce_product": {
4453 "op": Op.REDUCE_PRODUCT,
4454 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004455 "build_fcn": (
4456 build_reduce,
4457 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004458 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004459 TosaArgGen.agAxis,
4460 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004461 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004462 "error_if_validators": (
4463 TosaErrorValidator.evAxisLargerRank,
4464 TosaErrorValidator.evAxisSmallerZero,
4465 TosaErrorValidator.evShapeOfAxisNotOne,
4466 TosaErrorValidator.evWrongInputType,
4467 TosaErrorValidator.evWrongOutputType,
4468 TosaErrorValidator.evWrongRank,
4469 TosaErrorValidator.evWrongInputList,
4470 TosaErrorValidator.evWrongOutputList,
4471 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004472 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004473 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004474 "reduce_sum": {
4475 "op": Op.REDUCE_SUM,
4476 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004477 "build_fcn": (
4478 build_reduce,
4479 TosaTensorGen.tgBasic,
4480 TosaTensorValuesGen.tvgReduceSum,
4481 TosaArgGen.agAxis,
4482 ),
James Ward24dbc422022-10-19 12:20:31 +01004483 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004484 "error_if_validators": (
4485 TosaErrorValidator.evAxisLargerRank,
4486 TosaErrorValidator.evAxisSmallerZero,
4487 TosaErrorValidator.evShapeOfAxisNotOne,
4488 TosaErrorValidator.evWrongInputType,
4489 TosaErrorValidator.evWrongOutputType,
4490 TosaErrorValidator.evWrongRank,
4491 TosaErrorValidator.evWrongInputList,
4492 TosaErrorValidator.evWrongOutputList,
4493 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004494 "data_gen": DOT_PRODUCT_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004495 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004496 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004497 "concat": {
4498 "op": Op.CONCAT,
4499 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004500 "build_fcn": (
4501 build_concat,
4502 TosaTensorGen.tgConcat,
4503 TosaTensorValuesGen.tvgConcat,
4504 TosaArgGen.agAxis,
4505 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004506 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004507 "error_if_validators": (
4508 TosaErrorValidator.evAxisLargerRank,
4509 TosaErrorValidator.evAxisSmallerZero,
4510 TosaErrorValidator.evConcatInputRankMismatch,
4511 TosaErrorValidator.evConcatShapeSumMismatch,
4512 TosaErrorValidator.evConcatInputDimMismatch,
4513 TosaErrorValidator.evWrongInputType,
4514 TosaErrorValidator.evWrongOutputType,
4515 TosaErrorValidator.evWrongOutputList,
4516 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004517 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004518 },
4519 "pad": {
4520 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004521 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004522 "build_fcn": (
4523 build_pad,
4524 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004525 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004526 TosaArgGen.agPad,
4527 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004528 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004529 "error_if_validators": (
4530 TosaErrorValidator.evWrongInputType,
4531 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004532 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004533 TosaErrorValidator.evWrongOutputType,
4534 TosaErrorValidator.evWrongInputList,
4535 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004536 TosaErrorValidator.evRankMismatch,
4537 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004538 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004539 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004540 },
Won Jeona21b2e82023-08-10 10:33:01 +00004541 "dim": {
4542 "op": Op.DIM,
4543 "operands": (1, 0),
4544 "build_fcn": (
4545 build_dim,
4546 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004547 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004548 TosaArgGen.agAxis,
4549 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004550 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004551 "error_if_validators": (
4552 TosaErrorValidator.evAxisLargerRank,
4553 TosaErrorValidator.evAxisSmallerZero,
4554 TosaErrorValidator.evWrongInputType,
4555 TosaErrorValidator.evWrongInputList,
4556 TosaErrorValidator.evWrongOutputList,
4557 TosaErrorValidator.evWrongRank,
4558 ),
4559 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004560 "reshape": {
4561 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004562 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004563 "build_fcn": (
4564 build_reshape,
4565 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004566 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004567 TosaArgGen.agReshape,
4568 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004569 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004570 "error_if_validators": (
4571 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4572 TosaErrorValidator.evWrongInputType,
4573 TosaErrorValidator.evWrongOutputType,
4574 TosaErrorValidator.evWrongInputList,
4575 TosaErrorValidator.evWrongOutputList,
4576 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004577 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004578 },
4579 "reverse": {
4580 "op": Op.REVERSE,
4581 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004582 "build_fcn": (
4583 build_reverse,
4584 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004585 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004586 TosaArgGen.agAxis,
4587 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004588 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004589 "error_if_validators": (
4590 TosaErrorValidator.evAxisSmallerZero,
4591 TosaErrorValidator.evAxisLargerRank,
4592 TosaErrorValidator.evWrongInputType,
4593 TosaErrorValidator.evWrongOutputType,
4594 TosaErrorValidator.evWrongInputList,
4595 TosaErrorValidator.evWrongOutputList,
4596 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004597 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004598 },
4599 "slice": {
4600 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004601 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004602 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004603 "build_fcn": (
4604 build_slice,
4605 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004606 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004607 TosaArgGen.agSlice,
4608 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004609 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004610 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004611 # TODO Turn off these error categories for now as the reference
4612 # model cannot allocate memory space for empty tensor. We probably
4613 # can report an accurate error messege at the right place during
4614 # exeuction.
4615 # TosaErrorValidator.evStartSmallerZero,
4616 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004617 TosaErrorValidator.evStartSizeOutsideBounds,
4618 TosaErrorValidator.evSizeOutputShapeMismatch,
4619 TosaErrorValidator.evInputSizeStartLengthMismatch,
4620 TosaErrorValidator.evWrongRank,
4621 TosaErrorValidator.evWrongInputType,
4622 TosaErrorValidator.evWrongOutputType,
4623 TosaErrorValidator.evWrongInputList,
4624 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004625 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004626 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004627 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004628 },
4629 "tile": {
4630 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004631 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004632 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004633 "build_fcn": (
4634 build_tile,
4635 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004636 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004637 TosaArgGen.agTile,
4638 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004639 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004640 "error_if_validators": (
4641 TosaErrorValidator.evWrongInputType,
4642 TosaErrorValidator.evWrongOutputType,
4643 TosaErrorValidator.evWrongInputList,
4644 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004645 TosaErrorValidator.evRankMismatch,
4646 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004647 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004648 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004649 },
4650 "transpose": {
4651 "op": Op.TRANSPOSE,
4652 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004653 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004654 "build_fcn": (
4655 build_transpose,
4656 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004657 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004658 TosaArgGen.agTranspose,
4659 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004660 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004661 "error_if_validators": (
4662 TosaErrorValidator.evIndexOutsideBounds,
4663 TosaErrorValidator.evIndexUsedTwice,
4664 TosaErrorValidator.evWrongInputType,
4665 TosaErrorValidator.evWrongOutputType,
4666 TosaErrorValidator.evWrongInputList,
4667 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004668 TosaErrorValidator.evWrongRank,
4669 TosaErrorValidator.evRankMismatch,
4670 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004671 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004672 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004673 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004674 # Data nodes
4675 "const": {
4676 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004677 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004678 "build_fcn": (
4679 build_const,
4680 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004681 TosaTensorValuesGen.tvgLazyGenDefault,
4682 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004683 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004684 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha01ad8e1e22024-03-19 12:42:17 +00004685 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004686 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004687 "identity": {
4688 "op": Op.IDENTITY,
4689 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004690 "build_fcn": (
4691 build_unary,
4692 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004693 TosaTensorValuesGen.tvgLazyGenDefault,
4694 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004695 ),
evacha011adff832024-03-06 17:33:44 +00004696 "types": TYPE_FIB + [DType.INT4, DType.INT48],
evacha01ad8e1e22024-03-19 12:42:17 +00004697 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004698 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004699 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004700 "gather": {
4701 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004702 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004703 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004704 "build_fcn": (
4705 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004706 TosaTensorGen.tgGather,
4707 TosaTensorValuesGen.tvgGather,
4708 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004709 ),
James Ward24dbc422022-10-19 12:20:31 +01004710 "types": (
4711 DType.INT8,
4712 DType.INT16,
4713 DType.INT32,
4714 DType.FP16,
4715 DType.BF16,
4716 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004717 DType.FP8E4M3,
4718 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004719 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004720 "error_if_validators": (
4721 TosaErrorValidator.evWrongInputType,
4722 TosaErrorValidator.evWrongOutputType,
4723 TosaErrorValidator.evWrongInputList,
4724 TosaErrorValidator.evWrongOutputList,
4725 TosaErrorValidator.evWrongRank,
4726 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004727 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004728 },
4729 "scatter": {
4730 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004731 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004732 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004733 "build_fcn": (
4734 build_scatter,
4735 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004736 TosaTensorValuesGen.tvgScatter,
4737 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004738 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004739 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004740 "error_if_validators": (
4741 TosaErrorValidator.evWrongInputType,
4742 TosaErrorValidator.evWrongOutputType,
4743 TosaErrorValidator.evWrongInputList,
4744 TosaErrorValidator.evWrongOutputList,
4745 TosaErrorValidator.evWrongRank,
4746 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004747 "data_gen": PSEUDO_RANDOM_DATAGEN,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004748 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004749 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004750 "resize": {
4751 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004752 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004753 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004754 "build_fcn": (
4755 build_resize,
4756 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004757 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004758 TosaArgGen.agResize,
4759 ),
James Ward24dbc422022-10-19 12:20:31 +01004760 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004761 "invalid_test_validators": (
4762 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004763 ),
4764 "error_if_validators": (
4765 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004766 TosaErrorValidator.evScaleSmallerEqualZero,
4767 TosaErrorValidator.evScaleNLargerMax,
4768 TosaErrorValidator.evScaleDLargerMax,
4769 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004770 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004771 TosaErrorValidator.evBorderSmallerMin,
4772 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004773 TosaErrorValidator.evWrongInputType,
4774 TosaErrorValidator.evWrongOutputType,
4775 TosaErrorValidator.evWrongRank,
4776 TosaErrorValidator.evWrongInputList,
4777 TosaErrorValidator.evWrongOutputList,
4778 TosaErrorValidator.evBatchMismatch,
4779 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004780 TosaErrorValidator.evResizeOutputShapeMismatch,
4781 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004782 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004783 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004784 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004785 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004786 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004787 "cast": {
4788 "op": Op.CAST,
4789 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004790 "build_fcn": (
4791 build_cast,
4792 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004793 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004794 TosaArgGen.agCast,
4795 ),
James Ward8b390432022-08-12 20:48:56 +01004796 "types": (
4797 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004798 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004799 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004800 DType.INT8,
4801 DType.INT16,
4802 DType.INT32,
4803 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004804 DType.FP8E4M3,
4805 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004806 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004807 "error_if_validators": (
4808 TosaErrorValidator.evWrongInputType,
4809 TosaErrorValidator.evWrongOutputType,
4810 TosaErrorValidator.evWrongInputList,
4811 TosaErrorValidator.evWrongOutputList,
4812 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004813 "data_gen": PSEUDO_RANDOM_DATAGEN,
Jeremy Johnson708da822023-11-15 16:25:45 +00004814 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004815 },
4816 "rescale": {
4817 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004818 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004819 "build_fcn": (
4820 build_rescale,
4821 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004822 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004823 TosaArgGen.agRescale,
4824 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004825 "types": [
4826 DType.UINT8,
4827 DType.INT8,
4828 DType.INT16,
4829 DType.INT32,
4830 DType.INT48,
4831 DType.UINT16,
4832 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004833 "error_if_validators": (
4834 TosaErrorValidator.evInputZeroPointNotZero,
4835 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004836 TosaErrorValidator.evU16InputZeroPointNotValid,
4837 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004838 TosaErrorValidator.evScaleTrue,
4839 TosaErrorValidator.evScaleNotTrue,
4840 TosaErrorValidator.evWrongInputType,
4841 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004842 TosaErrorValidator.evWrongInputList,
4843 TosaErrorValidator.evWrongOutputList,
4844 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004845 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004846 # Custom
4847 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004848 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004849 # Two varients of cond_if, one that generates one of two constant tensors (no
4850 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4851 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004852 "cond_if_const": {
4853 "op": Op.COND_IF,
4854 "operands": (0, 2),
4855 "build_fcn": (
4856 build_cond_if_const,
4857 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004858 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004859 TosaArgGen.agCondIf,
4860 ),
4861 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004862 "error_if_validators": (
4863 TosaErrorValidator.evOutputListThenGraphMismatch,
4864 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004865 TosaErrorValidator.evCondIfCondNotMatchingBool,
4866 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004867 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004868 },
4869 "cond_if_binary": {
4870 "op": Op.COND_IF,
4871 "operands": (2, 0),
4872 "build_fcn": (
4873 build_cond_if_binary,
4874 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004875 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004876 TosaArgGen.agCondIf,
4877 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004878 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004879 "error_if_validators": (
4880 TosaErrorValidator.evInputListThenGraphMismatch,
4881 TosaErrorValidator.evInputListElseGraphMismatch,
4882 TosaErrorValidator.evOutputListThenGraphMismatch,
4883 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004884 TosaErrorValidator.evCondIfCondNotMatchingBool,
4885 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004886 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004887 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004888 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004889 "while_loop": {
4890 "op": Op.WHILE_LOOP,
4891 "operands": (0, 1),
4892 "build_fcn": (
4893 build_while_loop,
4894 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004895 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004896 TosaArgGen.agWhileLoop,
4897 ),
4898 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004899 "error_if_validators": (
4900 TosaErrorValidator.evInputListOutputListMismatch,
4901 TosaErrorValidator.evInputListCondGraphMismatch,
4902 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4903 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4904 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004905 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004906 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004907 },
Luke Hutton57287132023-02-06 14:54:18 +00004908 "fft2d": {
4909 "op": Op.FFT2D,
4910 "operands": (2, 0),
4911 "rank": (3, 3),
4912 "build_fcn": (
4913 build_fft2d,
4914 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004915 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004916 TosaArgGen.agFFT2d,
4917 ),
4918 "types": [DType.FP32],
4919 "error_if_validators": (
4920 TosaErrorValidator.evWrongInputType,
4921 TosaErrorValidator.evWrongOutputType,
4922 TosaErrorValidator.evWrongInputList,
4923 TosaErrorValidator.evWrongOutputList,
4924 TosaErrorValidator.evWrongRank,
4925 TosaErrorValidator.evBatchMismatch,
4926 TosaErrorValidator.evKernelNotPowerOfTwo,
4927 TosaErrorValidator.evFFTInputShapeMismatch,
4928 TosaErrorValidator.evFFTOutputShapeMismatch,
4929 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004930 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton57287132023-02-06 14:54:18 +00004931 },
Luke Hutton261b7b62023-01-10 14:50:31 +00004932 "rfft2d": {
4933 "op": Op.RFFT2D,
4934 "operands": (1, 0),
4935 "rank": (3, 3),
4936 "build_fcn": (
4937 build_rfft2d,
4938 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00004939 TosaTensorValuesGen.tvgLazyGenDefault,
4940 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00004941 ),
4942 "types": [DType.FP32],
4943 "error_if_validators": (
4944 TosaErrorValidator.evWrongInputType,
4945 TosaErrorValidator.evWrongOutputType,
4946 TosaErrorValidator.evWrongInputList,
4947 TosaErrorValidator.evWrongOutputList,
4948 TosaErrorValidator.evWrongRank,
4949 TosaErrorValidator.evBatchMismatch,
4950 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00004951 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00004952 ),
evacha01ad8e1e22024-03-19 12:42:17 +00004953 "data_gen": DOT_PRODUCT_DATAGEN,
Luke Hutton261b7b62023-01-10 14:50:31 +00004954 },
Won Jeon74342e52024-01-09 00:34:40 +00004955 # Shape
4956 "add_shape": {
4957 "op": Op.ADD_SHAPE,
4958 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004959 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004960 "build_fcn": (
4961 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004962 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004963 TosaTensorValuesGen.tvgAddSub,
4964 TosaArgGen.agNone,
4965 ),
4966 "types": [DType.SHAPE],
4967 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4968 },
4969 "sub_shape": {
4970 "op": Op.SUB_SHAPE,
4971 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004972 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004973 "build_fcn": (
4974 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004975 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004976 TosaTensorValuesGen.tvgAddSub,
4977 TosaArgGen.agNone,
4978 ),
4979 "types": [DType.SHAPE],
4980 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4981 },
4982 "mul_shape": {
4983 "op": Op.MUL_SHAPE,
4984 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004985 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004986 "build_fcn": (
4987 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004988 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00004989 TosaTensorValuesGen.tvgMul,
4990 TosaArgGen.agNone,
4991 ),
4992 "types": [DType.SHAPE],
4993 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
4994 },
4995 "div_shape": {
4996 "op": Op.DIV_SHAPE,
4997 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00004998 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00004999 "build_fcn": (
5000 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005001 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005002 TosaTensorValuesGen.tvgIntDiv,
5003 TosaArgGen.agNone,
5004 ),
5005 "types": [DType.SHAPE],
5006 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5007 },
5008 "concat_shape": {
5009 "op": Op.CONCAT_SHAPE,
5010 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005011 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005012 "build_fcn": (
5013 build_concat,
5014 TosaTensorGen.tgConcat,
5015 TosaTensorValuesGen.tvgConcat,
5016 TosaArgGen.agNone,
5017 ),
5018 "types": [DType.SHAPE],
5019 "error_if_validators": (),
5020 },
5021 "const_shape": {
5022 "op": Op.CONST_SHAPE,
5023 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005024 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005025 "build_fcn": (
5026 build_const,
5027 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005028 TosaTensorValuesGen.tvgLazyGenDefault,
5029 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005030 ),
5031 "types": [DType.SHAPE],
5032 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005033 }
5034
Kevin Cheng550ccc52021-03-03 11:21:43 -08005035
Eric Kunzee5e26762020-10-13 16:11:07 -07005036class OutputShaper:
5037 # Methods in this class compute the expected output shape and datatype
5038 # for common classes of operations
5039 def __init__(self):
5040 pass
5041
5042 # These methods return arguments that can be used for
5043 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of a broadcasting elementwise binary op.

    Each output dimension is taken from ``a``. A size-1 dimension of ``a``
    is replaced by the matching dimension of ``b``, but only when no
    ERROR_IF case is being generated.

    Args:
        ser: Serializer; the output is created with ``ser.addOutput``.
        rng: Random generator used to fuzz ERROR_IF outputs.
        a: First input tensor (its ``shape`` and ``dtype`` are read).
        b: Second input tensor. It must share ``a``'s dtype, and also its
            rank unless ``error_name`` is ``ErrorIf.RankMismatch``.
        error_name: Optional ERROR_IF case to inject.
            ``ErrorIf.DimensionMismatch`` adds 1 to one randomly chosen
            output dimension. ``ErrorIf.WrongOutputType`` picks a random
            dtype different from ``a.dtype``.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    may_broadcast = error_name is None
    out_shape = [
        b.shape[idx] if (dim == 1 and may_broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn even when unused, so the RNG stream does not depend on which
    # (if any) ERROR_IF case is requested.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005077
5078 @staticmethod
5079 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005080 assert len(a.shape) == len(b.shape)
5081 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005082
5083 shape = []
5084 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005085 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005086 shape.append(a.shape[i])
5087
Kevin Cheng550ccc52021-03-03 11:21:43 -08005088 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005089
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor of an elementwise unary op.

    The output always has ``a``'s shape. Its dtype is ``a.dtype``, unless
    ``error_name`` is ``ErrorIf.WrongOutputType``; in that case a random
    dtype different from ``a.dtype`` is chosen.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005108
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor of SELECT(cond, a, b).

    Each output dimension is taken from ``cond``. When ``cond`` has a
    size-1 dimension and no ERROR_IF case is requested, the largest of the
    matching ``cond``/``a``/``b`` dimensions is used instead.

    Args:
        ser: Serializer; the output is created with ``ser.addOutput``.
        rng: Random generator used to fuzz ERROR_IF outputs.
        cond: Condition tensor (only its ``shape`` is read).
        a: First value tensor.
        b: Second value tensor; must share ``a``'s dtype.
        error_name: Optional ERROR_IF case to inject
            (``ErrorIf.DimensionMismatch`` or ``ErrorIf.WrongOutputType``
            alter the output here).

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and error_name is None)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # The fuzz index is sized by ``a``'s rank and is always drawn.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - set([a.dtype])))
    else:
        out_dtype = a.dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005142
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor of an elementwise comparison op.

    A size-1 dimension of ``a`` takes ``b``'s size whenever ``b`` has that
    dimension at all. Unlike ``binaryBroadcastOp``, this broadcast is
    applied even when an ERROR_IF case is requested. The output dtype is
    BOOL, except for ``ErrorIf.WrongOutputType``, which picks a random
    non-BOOL dtype.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    b_rank = len(b.shape)
    out_shape = [
        b.shape[idx] if (dim == 1 and b_rank > idx) else dim
        for idx, dim in enumerate(a.shape)
    ]

    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        # None of these candidates is BOOL, so any pick is a wrong type.
        non_bool_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(non_bool_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005176
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of a reduction over ``axis``.

    Normally the output is ``a``'s shape with ``axis`` set to 1.

    * For ``ErrorIf.AxisSmallerZero`` and ``ErrorIf.AxisLargerRank``, the
      shape is left equal to ``a.shape``.
    * For ``ErrorIf.ShapeOfAxisNotOne``, the dimension keeps ``a``'s size.
      If that size is 1, it is replaced by a random integer in [2, 10).
    * For ``ErrorIf.WrongOutputType``, a random dtype different from
      ``a.dtype`` is chosen.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    axis_untouched = error_name in [
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    ]
    if not axis_untouched:
        out_shape[axis] = 1
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, a.dtype)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([a.dtype]))
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005205
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor of ARGMAX over ``axis``.

    The output is ``a``'s shape with ``axis`` removed, and its dtype is
    INT32. ERROR_IF variants:

    * AxisSmallerZero / AxisLargerRank: ``axis`` is not removed.
    * ArgmaxOutputRankMismatch: randomly drop the leading dimension (only
      if more than one dimension remains), otherwise append a trailing 1.
    * ArgmaxOutputShapeMismatch: add a random integer in [1, 10) to every
      dimension.
    * WrongOutputType: pick a random dtype other than INT32.

    Returns:
        The value returned by ``ser.addOutput``.
    """
    out_shape = a.shape.copy()
    if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        drop_leading = rng.choice([True, False])
        if drop_leading and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidate_dtypes = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    ]
    wrong_dtypes = list(set(candidate_dtypes) - set([DType.INT32]))
    return ser.addOutput(out_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005241
@staticmethod
def _get_conv_output_type(input_dtype):
    """Return the output dtype a convolution produces for ``input_dtype``.

    Dtype mapping:

    * FP16, BF16 and FP32 inputs keep their dtype.
    * FP8E4M3 and FP8E5M2 inputs produce FP16.
    * INT8 and INT4 inputs produce INT32.
    * INT16 inputs produce INT48.

    Raises:
        AssertionError: if ``input_dtype`` is not one of the supported
            convolution input dtypes listed above.
    """
    if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
        return input_dtype
    elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        return DType.FP16
    elif input_dtype in (DType.INT8, DType.INT4):
        return DType.INT32
    elif input_dtype in (DType.INT16,):
        return DType.INT48
    # Fail loudly instead of implicitly returning None. An
    # ``assert True, ...`` can never trigger. The explicit raise also
    # survives ``python -O``.
    raise AssertionError(f"Unsupported convolution data type {input_dtype}")
5253
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Add the output tensor of a CONV2D op.

    Layouts: IFM is NHWC, the filter is OHWI and the OFM is NHWC. Each
    spatial output size is computed as
    ``(in - 1 + pad_a + pad_b - (kernel - 1) * dilation) // stride + 1``.

    Args:
        ser: Serializer; the output is created with ``ser.addOutput``.
        rng: Random generator used for ERROR_IF fuzzing.
        ifm: Input feature map tensor (its ``shape`` and ``dtype`` are read).
        filter: Weight tensor (only its ``shape`` is read).
        accum_dtype: Accumulator dtype; not read by this method.
        strides: Strides; index 0 applies to H, index 1 to W.
        padding: Four pad amounts; indices 0-1 apply to H, 2-3 to W.
        dilations: Dilations; index 0 applies to H, index 1 to W.
        error_name: Optional ERROR_IF case to inject.

    Returns:
        The value returned by ``ser.addOutput``.
    """

    def spatial_size(in_size, pad_a, pad_b, kernel, dilation, stride):
        # Output extent along one spatial axis for the given geometry.
        span = in_size - 1 + pad_a + pad_b - (kernel - 1) * dilation
        return span // stride + 1

    out_h = spatial_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = spatial_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        axes = rng.choice(steps)  # 1: grow H only, 2: W only, 3: both
        # Grow by whole multiples of the stride, so the non-integer
        # output-shape error case is not triggered instead.
        if axes in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if axes in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is deliberately invalid; use a plausible output dtype.
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excluded = [DType.FP16]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005307
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of a CONV3D.

    Layouts: ``ifm`` is NDHWC, ``filter`` is ODHWI and the output is NDHWC.
    ``padding[0:2]``, ``padding[2:4]`` and ``padding[4:6]`` pad depth, height
    and width; ``strides``/``dilations`` indices 0/1/2 apply to the same axes.
    ``accum_dtype`` is accepted for interface compatibility but unused here.

    ERROR_IF handling: ``ConvOutputShapeMismatch`` enlarges spatial dims,
    ``WrongInputType`` uses INT32 as a placeholder output type and
    ``WrongOutputType`` draws a deliberately wrong output type.
    """
    # Spatial output size per axis (0 = depth, 1 = height, 2 = width)
    spatial = []
    for axis in range(3):
        padded = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        dilated_kernel = (filter.shape[axis + 1] - 1) * dilations[axis]
        spatial.append((padded - dilated_kernel) // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3, 4]
        change = rng.choice(choices)
        # 1 -> depth, 2 -> height, 3 -> width, 4 -> all. Grow in multiples of
        # the stride so the non-integer output error case is not hit instead.
        for axis in range(3):
            if change in (axis + 1, 4):
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], spatial[2], filter.shape[0]]

    out_dtype = (
        DType.INT32  # input type already invalid: any plausible output will do
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)

@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output tensor of a DEPTHWISE_CONV2D.

    Layouts: ``ifm`` is NHWC, ``filter`` is HWCM and the output is
    NHW(C*M). ``padding[0:2]`` pads the height and ``padding[2:4]`` the
    width; ``strides``/``dilations`` index 0 applies to height, 1 to width.
    ``accum_dtype`` is accepted for interface compatibility but unused here.

    ERROR_IF handling: ``ConvOutputShapeMismatch`` enlarges the spatial size,
    ``WrongInputType`` uses INT32 as a placeholder output type and
    ``WrongOutputType`` draws a deliberately wrong output type.
    """
    # Spatial output size per axis (0 = height, 1 = width); the kernel's
    # spatial dims are the first two filter dims in HWCM layout.
    spatial = []
    for axis in range(2):
        padded = ifm.shape[axis + 1] - 1 + padding[2 * axis] + padding[2 * axis + 1]
        dilated_kernel = (filter.shape[axis] - 1) * dilations[axis]
        spatial.append((padded - dilated_kernel) // strides[axis] + 1)

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 1 -> height, 2 -> width, 3 -> both. Grow in multiples of the stride
        # so the non-integer output error case is not hit instead.
        for axis in range(2):
            if change in (axis + 1, 3):
                spatial[axis] = spatial[axis] + rng.choice(choices) * strides[axis]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], spatial[0], spatial[1], out_channels]

    out_dtype = (
        DType.INT32  # input type already invalid: any plausible output will do
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)

@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the output tensor of a 2D pooling operator (NHWC in and out).

    ``kernel``/``stride`` index 0 applies to height and 1 to width;
    ``pad[0:2]`` pads the height and ``pad[2:4]`` the width. The output dtype
    matches ``ifm`` unless ``error_name`` asks for a wrong output type.
    """
    if stride[0] > 0 and stride[1] > 0 and min(pad) >= 0:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # Invalid stride/padding: the test is invalid anyway, so use 1x1
        out_h = out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # Grow in multiples of the stride to avoid the non-integer error case
        if change in (1, 3):
            out_h = out_h + rng.choice(choices) * stride[0]
        if change in (2, 3):
            out_w = out_w + rng.choice(choices) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        out_dtype = rng.choice(list(set(all_dtypes) - {ifm.dtype}))
    else:
        out_dtype = ifm.dtype

    return ser.addOutput(ofm_shape, out_dtype)

@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the output tensor of a FULLY_CONNECTED.

    ``input`` is (N, IC), ``filter`` is (OC, IC) and the output is (N, OC).
    The output dtype is ``accum_dtype``, which the argument generator has
    already validated (or deliberately invalidated for ERROR_IF tests).
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)

@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the output tensor of a MATMUL.

    ``a`` is (N, H, C), ``b`` is (N, C, W) and the output is (N, H, W).
    The output dtype is normally ``accum_dtype`` (validated in arg_gen).
    For ``WrongInputType`` INT32 is used as a placeholder, and for
    ``WrongOutputType`` a type outside the excluded set for ``a.dtype`` is
    drawn with ``rng``.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongInputType:
        # Input type is already invalid; any plausible output type will do
        return ser.addOutput(output_shape, DType.INT32)
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, accum_dtype)

    # Candidate pool; its order fixes which type rng.choice picks
    every_dtype = (
        DType.INT4,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    # Output types removed from the pool for each supported input type
    if a.dtype == DType.INT8:
        excluded_out = (DType.INT32,)
    elif a.dtype == DType.INT16:
        excluded_out = (DType.INT48,)
    elif a.dtype in (DType.FP8E4M3, DType.FP8E5M2):
        excluded_out = (DType.FP16,)
    elif a.dtype in (DType.FP32, DType.FP16, DType.BF16):
        excluded_out = (DType.FP32, DType.FP16, DType.BF16)
    excluded = set(excluded_out)
    incorrect_types = tuple(dt for dt in every_dtype if dt not in excluded)
    return ser.addOutput(output_shape, rng.choice(a=incorrect_types))

@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor of a CONCAT of ``inputs`` along ``axis``.

    The output shape is the first input's shape with the ``axis`` dimension
    summed over all inputs. For ERROR_IF cases where that sum cannot be
    formed (rank mismatch, axis out of range) the first input's shape is
    used unchanged. The dtype matches the first input unless a wrong output
    type is requested.
    """
    first = inputs[0]
    others = inputs[1:]

    output_shape = first.shape.copy()
    can_sum = error_name not in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if can_sum:
        for tensor in others:
            output_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, first.dtype)

    all_dtypes = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    wrong_dtypes = list(all_dtypes - {first.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))

@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor of a PAD.

    Each dimension of ``a`` grows by ``padding[i][0] + padding[i][1]``.
    ERROR_IF cases may bump one dimension, change the rank, or pick a
    wrong output dtype; otherwise the dtype matches ``a``.
    """
    output_shape = a.shape.copy()
    for idx, dim in enumerate(output_shape):
        output_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)

@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the output tensor of a DIM: shape [1] with dtype SHAPE.

    For ``WrongOutputType`` a non-SHAPE dtype is drawn with ``rng``.
    ``a`` and ``axis`` are not used here.
    """
    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput([1], DType.SHAPE)

    # None of these is DType.SHAPE, so every candidate is a wrong type
    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    return ser.addOutput([1], rng.choice(list(set(candidates))))

@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor of a RESHAPE of ``a`` to ``shape``.

    ``TensorSizeInputOutputMismatch`` inflates every dimension so the element
    count no longer matches; ``WrongOutputType`` draws a dtype different from
    ``a``'s. Otherwise the dtype matches ``a``.
    """
    output_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(output_shape):
            output_shape[idx] = dim + rng.integers(1, 10)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)

@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor of a SLICE.

    The output shape is ``size`` and the dtype matches ``input``. ERROR_IF
    cases perturb each dimension, substitute the input shape, change the
    rank, or pick a wrong dtype. ``start`` is not used here.
    """
    output_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(output_shape):
            # Dims of 2 or less only grow; larger dims may shrink or grow
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            output_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        output_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(all_dtypes) - {input.dtype}))
    else:
        output_dtype = input.dtype

    return ser.addOutput(output_shape, output_dtype)

@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor of a TILE: each dim of ``a`` times its multiple.

    ``RankMismatch`` replaces the shape with one of a different rank and
    ``WrongOutputType`` draws a dtype different from ``a``'s.
    """
    output_shape = a.shape.copy()
    assert len(multiples) == len(output_shape)

    for idx, dim in enumerate(a.shape):
        output_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)

@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor of a TRANSPOSE of ``a`` by ``perms``.

    Output dim ``i`` is ``a.shape[perms[i]]``, except for the invalid-index
    ERROR_IF cases where the input shape is kept. Further ERROR_IF cases
    inflate every dim, change the rank, or pick a wrong dtype.
    """
    output_shape = a.shape.copy()
    assert len(perms) == len(output_shape)

    invalid_perms = error_name in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]
    if not invalid_perms:
        for dst_axis in range(len(output_shape)):
            output_shape[dst_axis] = a.shape[perms[dst_axis]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dst_axis in range(len(output_shape)):
            output_shape[dst_axis] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice(list(set(all_dtypes) - {a.dtype}))
    else:
        output_dtype = a.dtype

    return ser.addOutput(output_shape, output_dtype)

@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor of a GATHER.

    ``values`` is rank 3 and ``indices`` rank 2 with matching first dims
    (checked unless the ERROR_IF case is ``WrongRank``). The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]`` and the dtype
    matches ``values`` unless a wrong output type is requested.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

    output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, values.dtype)

    all_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtype = rng.choice(list(set(all_dtypes) - {values.dtype}))
    return ser.addOutput(output_shape, wrong_dtype)

@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor of a SCATTER.

    Unless the ERROR_IF case is ``WrongRank``, checks that ``values_in`` and
    ``input`` are rank 3 and ``indices`` rank 2 with consistent N/W/C dims.
    The output has ``values_in``'s shape and, unless a wrong output type is
    requested, its dtype.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

    if error_name == ErrorIf.WrongOutputType:
        all_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
        output_dtype = rng.choice(list(set(all_dtypes) - {values_in.dtype}))
    else:
        output_dtype = values_in.dtype

    return ser.addOutput(values_in.shape, output_dtype)

@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor of a TABLE lookup (same shape as ``input``).

    INT16 input gives INT32 output and any other input gives INT8. The
    INT8/INT16 input check is skipped for ``WrongInputType``; for
    ``WrongOutputType`` a different dtype is drawn with ``rng``.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        output_dtype = rng.choice([dt for dt in candidates if dt != output_dtype])

    return ser.addOutput(input.shape, output_dtype)

@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor of a RESIZE (NHWC in and out).

    ``scale`` is (y_n, y_d, x_n, x_d); ``offset`` and ``border`` index 0
    applies to y and 1 to x. ``mode`` and ``input_dtype`` are not used here.
    For ERROR_IF tests the computed height/width is clamped to a usable
    range and may be perturbed, and the batch/channel/rank may be made wrong.
    """
    y_num, y_den, x_num, x_den = scale[0], scale[1], scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Clamp to >= 1 so the shape calculation below stays well-defined
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    oh = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    ow = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # ERROR_IF changes to scale/offset/border can yield an unusable size
        oh, ow = max(oh, 1), max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            oh, ow = min(oh, limit), min(ow, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:

        def step_size(size, step):
            """Move ``size`` by ``step``; shrink if growing would hit the max."""
            if size + step >= gtu.MAX_RESIZE_DIMENSION:
                size -= step
                assert size > 0  # Should have been caught in agResize
                return size
            return size + step

        change = rng.choice([1, 2, 3])
        # 1 -> height, 2 -> width, 3 -> both. Steps equal the scale
        # denominator so the non-integer error case is not hit instead.
        if change in (1, 3):
            oh = step_size(oh, y_den)
        if change in (2, 3):
            ow = step_size(ow, x_den)

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the channel slot repeats the batch size, presumably
        # because a wrong-rank input may have no index 3 -- confirm intent.
        output_dims = [batch, oh, ow, batch]
    else:
        channels = input.shape[3]
        if error_name == ErrorIf.BatchMismatch:
            batch = batch + rng.integers(1, 10)
        elif error_name == ErrorIf.ChannelMismatch:
            channels = channels + rng.integers(1, 10)
        output_dims = [batch, oh, ow, channels]

    return serializer.addOutput(output_dims, output_dtype)

@staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Create an output tensor shaped like ``val`` with dtype ``out_dtype``.

    ``rng`` and ``error_name`` are accepted for interface compatibility but
    are not used.
    """
    shape = val.shape
    return ser.addOutput(shape, out_dtype)

@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Create the output tensor of a TRANSPOSE_CONV2D.

    ``output_shape`` is supplied by the caller. For ``ConvOutputShapeMismatch``
    its entries 1 and/or 2 are enlarged *in place* before use.
    ``WrongInputType`` uses INT32 as a placeholder output type and
    ``WrongOutputType`` draws a deliberately wrong one. ``accum_dtype`` is
    not used here.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        choices = [1, 2, 3]
        change = rng.choice(choices)
        # 1 -> dim 1, 2 -> dim 2, 3 -> both.
        # NOTE(review): this mutates the caller's list; confirm no caller
        # relies on that before switching to a copy.
        for dim in (1, 2):
            if change in (dim, 3):
                output_shape[dim] = output_shape[dim] + rng.choice(choices)

    out_dtype = (
        DType.INT32  # input type already invalid: any plausible output will do
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(output_shape, out_dtype)

@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors for FFT2D.

    Both outputs share one shape and dtype, copied from ``ifm1`` unless an
    error case alters them.

    Args:
        serializer: serializer; outputs are created via ``addOutput``.
        rng: random generator, used only when an error is injected.
        ifm1, ifm2: the two input tensors (presumably real and imaginary
            parts — confirm). They must share a dtype. They must also share
            a shape unless ``FFTInputShapeMismatch`` is injected, and must
            be rank 3 unless ``WrongRank`` is injected.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        A list holding the two tensors returned by ``serializer.addOutput``.
    """
    assert ifm1.dtype == ifm2.dtype
    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Pick which axis to grow first, then draw the amount.
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
5965
@staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors for RFFT2D.

    Each output keeps every input dimension except the last, which becomes
    ``W // 2 + 1`` where ``W`` is the input's last dimension. The dtype
    matches ``value`` unless an error case alters it.

    Args:
        serializer: serializer; outputs are created via ``addOutput``.
        rng: random generator, used only when an error is injected.
        value: input tensor; must be rank 3 unless ``WrongRank`` is injected.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        A list holding the two tensors returned by ``serializer.addOutput``.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    out_shape = list(in_shape[:-1]) + [in_shape[-1] // 2 + 1]
    out_dtype = value.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        # Pick which axis to grow first, then draw the amount.
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00005990
@staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for ADD_SHAPE.

    Args:
        ser: serializer; the output is created via ``ser.addOutput``.
        rng: random generator. It always draws a candidate index to fuzz,
            and for ``WrongOutputType`` it also picks a wrong dtype.
        a, b: input tensors. They must share a dtype, and must share a rank
            unless ``RankMismatch`` is being injected.
        error_name: optional ``ErrorIf`` case to inject.

    Returns:
        The tensor returned by ``ser.addOutput``. Its shape is copied from
        ``a`` (one entry bumped by 1 for ``DimensionMismatch``). Its dtype is
        ``DType.SHAPE`` unless ``WrongOutputType`` is injected.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Copy so the fuzzing below never touches a.shape itself.
    shape = list(a.shape)

    # Drawn unconditionally so the number of RNG draws does not depend on
    # error_name. NOTE(review): for a rank-0 ``a`` this is
    # rng.integers(0, 0), which presumably raises; verify whether that input
    # can reach this shaper.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name == ErrorIf.WrongOutputType:
        output_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        output_dtype = DType.SHAPE
    return ser.addOutput(shape, output_dtype)