blob: 28b3d28023620d10e6f769f54d3d7e26f5bcf89f [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010023from generator.tosa_random_gen import TosaHashRandomGenerator
24from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010025from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000026from tosa.DType import DType
27from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010028
# Header prepended to every generated C++ meta data file; the copyright year
# is taken from the date at import time.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# Configure default logging output and create this tool's logger.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
36
Matthew Haddonb724efc2021-08-25 16:40:29 +010037
Eric Kunzee5e26762020-10-13 16:11:07 -070038class TosaTestGen:
# Level limits - these currently match the 8K level defined in the TOSA
# specification.
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192

# Settings for the main compliance dot product statistical tests: the number
# of test sets and a minimum value (see users of TOSA_MI_DOT_PRODUCT_MIN for
# its exact meaning).
TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
47
def __init__(self, args):
    """Initialise the test generator from parsed command-line ``args``.

    Records the output directory and random seed, builds the op lists
    (createDynamicOpLists / initOpListDefaults), creates the desc.json schema
    validator and the data generator library, computes the random value
    range for each dtype and finally creates the global random generator via
    resetGlobalRNG().
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    self.global_rng = None
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Used to validate every desc.json before it is written
    self.descSchemaValidator = TestDescSchemaValidator()
    # Loaded even with lazy data generation, as compliance set up may need it
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clampToDtype(rangeFP, maxFP):
        # "max"/"-max" become +/- maxFP; other values are clipped into
        # [-maxFP, maxFP] (zero is kept unchanged). Result is sorted.
        def clampOne(value):
            if value == "max":
                return maxFP
            if value == "-max":
                return -maxFP
            if value < 0:
                return max(value, -maxFP)
            if value > 0:
                return min(value, maxFP)
            return value

        return tuple(sorted(clampOne(value) for value in rangeFP))

    floatDtypes = (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2)
    dtypeRanges = {DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])}
    for floatDtype in floatDtypes:
        dtypeRanges[floatDtype] = clampToDtype(
            args.tensor_fp_value_range,
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[floatDtype],
        )
    self.random_dtype_range = dtypeRanges
    self.resetGlobalRNG()
92
def resetGlobalRNG(self):
    """Recreate ``self.global_rng`` from the stored seed and per-dtype value ranges."""
    seed, dtypeRanges = self.random_seed, self.random_dtype_range
    self.global_rng = TosaRandomGenerator(seed, dtypeRanges)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010095
def createSerializer(self, opName, testPath):
    """Create ``<basePath>/<opName>/<testPath>`` and a serializer writing into it.

    Sets ``self.testPath`` (relative to basePath) and ``self.ser``. The
    constant mode is chosen from the arguments:
      * lazy_data_gen -> ConstMode.INPUTS (constants become input files)
      * dump_consts   -> ConstMode.EMBED_DUMP
      * otherwise     -> ConstMode.EMBED (constants embedded in the flatbuffer)
    """
    self.testPath = os.path.join(opName, testPath)
    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    if self.args.lazy_data_gen:
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        constMode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700109
def getSerializer(self):
    """Return the serializer from the latest createSerializer() call (None before any call)."""
    serializer = self.ser
    return serializer
112
def serialize(self, testName, metaData=None):
    """Write the current test's files into ``<basePath>/<testPath>``.

    Outputs, in order:
      * ``<testName>.tosa`` - the serializer's flatbuffer binary
      * ``<testName>_meta_data_gen.cpp`` - only with lazy data generation and
        a ``data_gen`` entry in ``metaData``
      * ``<testName>_meta_compliance.cpp`` - when ``metaData`` has a
        ``compliance`` entry
      * ``desc.json`` - the serializer's JSON descriptor, with ``metaData``
        stored under "meta" when given

    The descriptor is schema-validated before any meta data file or
    desc.json is written.
    """
    testDir = Path(self.basePath) / self.testPath
    tosaName = f"{testName}.tosa"

    with (testDir / tosaName).open("wb") as fbFile:
        fbFile.write(self.ser.serialize())

    desc = json.loads(self.ser.writeJson(tosaName))
    if metaData:
        desc["meta"] = metaData
    self.descSchemaValidator.validate_config(desc)

    def writeCppMeta(fileName, commentText, varName, payload):
        # Embed the JSON payload in a C++ raw string literal
        with (testDir / fileName).open("w") as cppFile:
            cppFile.write(TOSA_AUTOGENERATED_HEADER)
            cppFile.write(commentText)
            cppFile.write(f'const char* {varName} = R"(')
            json.dump(payload, cppFile)
            cppFile.write(')";\n\n')

    if metaData:
        # Data generation set up is only needed when data is generated later
        if "data_gen" in metaData and self.args.lazy_data_gen:
            writeCppMeta(
                f"{testName}_meta_data_gen.cpp",
                "// Test meta data for data generation setup\n\n",
                f"json_tdg_config_{testDir.stem}",
                metaData["data_gen"],
            )
        # Compliance meta data is always written out when present
        if "compliance" in metaData:
            writeCppMeta(
                f"{testName}_meta_compliance.cpp",
                "// Test meta data for compliance validation\n\n",
                f"json_tvf_config_{testDir.stem}",
                metaData["compliance"],
            )

    with (testDir / "desc.json").open("w") as descFile:
        json.dump(desc, descFile, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700156
def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
    """Add one placeholder per (shape, dtype) pair to the serializer.

    Data comes from ``rng.randTensor`` unless lazy data generation is
    enabled, in which case None is passed as the data.

    Returns:
        The values returned by ``self.ser.addPlaceholder``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else rng.randTensor(shape, dtype)
        tensors.append(self.ser.addPlaceholder(shape, dtype, data))
    return tensors
169
def buildConstTensors(self, rng, shape_list, dtype_list):
    """Add one constant per (shape, dtype) pair to the serializer.

    Data comes from ``rng.randTensor`` unless lazy data generation is
    enabled, in which case None is passed as the data.

    Returns:
        The values returned by ``self.ser.addConst``, in input order.
    """
    assert len(shape_list) == len(dtype_list)

    tensors = []
    for shape, dtype in zip(shape_list, dtype_list):
        data = None if self.args.lazy_data_gen else rng.randTensor(shape, dtype)
        tensors.append(self.ser.addConst(shape, dtype, data))
    return tensors
182
def makeShape(self, rng, rank):
    """Return a shape as an int32 numpy array.

    If a target shape was set with setTargetShape(), that shape is returned
    (``rank`` is then ignored). Otherwise ``rank`` dimensions are drawn with
    ``rng.integers`` using args.tensor_shape_range[0] / [1] as low / high.
    """
    # Compare with None (the "not set" value) instead of relying on
    # truthiness: a numpy array target would raise "truth value of an array
    # is ambiguous", and an empty (rank 0) target would be silently ignored.
    if self.targetted_shape is not None:
        return np.int32(self.targetted_shape)
    return np.int32(
        rng.integers(
            low=self.args.tensor_shape_range[0],
            high=self.args.tensor_shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700193
def setTargetShape(self, shape):
    """Make makeShape() return ``shape``; set None to go back to random shapes."""
    self.targetted_shape = shape
196
def shapeStr(self, shape):
    """Format ``shape`` as "x"-separated dimensions, e.g. [1, 2, 3] -> "1x2x3"."""
    return "x".join(map(str, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700205
def typeStr(self, dtype):
    """Return the short name of ``dtype`` from gtu.DTYPE_ATTRIBUTES[...]["str"].

    A list/tuple of dtypes (at least two) becomes the names of its first two
    entries joined by "x". Any third entry (the accumulator type) is still
    converted but left out of the name.

    Raises:
        Exception: if a dtype is not in gtu.DTYPE_ATTRIBUTES.
    """
    if not isinstance(dtype, (list, tuple)):
        attributes = gtu.DTYPE_ATTRIBUTES
        if dtype not in attributes:
            raise Exception(
                "Unknown dtype, cannot convert to string: {}".format(dtype)
            )
        return attributes[dtype]["str"]

    assert len(dtype) >= 2
    names = [self.typeStr(member) for member in dtype]
    return "x".join(names[:2])
Eric Kunzee5e26762020-10-13 16:11:07 -0700219
def constrictBatchSize(self, shape):
    """Clamp ``shape[0]`` to args.max_batch_size in place and return ``shape``.

    Nothing changes when max_batch_size is unset/zero or when explicit
    target shapes were requested.
    """
    batchLimit = self.args.max_batch_size
    if batchLimit and not self.args.target_shapes:
        shape[0] = min(shape[0], batchLimit)
    return shape
225
def makeDimension(self, rng):
    """Draw one dimension size via ``rng.randInt`` bounded by args.tensor_shape_range[0:2]."""
    shapeRange = self.args.tensor_shape_range
    low, high = shapeRange[0], shapeRange[1]
    return rng.randInt(low=low, high=high)
230
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data for one expected output tensor.

    Returns None (no compliance checking) for ERROR_IF tests, for output
    dtypes not supported by compliance, and for the dot product style ops
    below when their input dtype is not supported. Otherwise returns a dict
    holding "mode" (a gtu.ComplianceMode name), "data_type" and, for some
    modes, one extra "*_info" entry. The first matching rule picks the mode:
    data generator type (DOT_PRODUCT, SPECIAL), then the op's "compliance"
    settings ("ulp", "relative"), then the op itself; EXACT otherwise.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    dotProductOps = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in dotProductOps
    ):
        return None

    # "mode" is filled in last; "data_type" is needed for all FP runs as the
    # reference model's precise mode produces FP64
    compliance = {
        "mode": None,
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }
    opCompliance = op["compliance"] if "compliance" in op else {}
    dgType = argsDict["dg_type"]
    modes = gtu.ComplianceMode

    if dgType == gtu.DataGenType.DOT_PRODUCT:
        mode = modes.DOT_PRODUCT
        compliance["dot_product_info"] = {
            "s": argsDict["s"],
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif dgType == gtu.DataGenType.SPECIAL:
        mode = modes.FP_SPECIAL
    elif "ulp" in opCompliance:
        mode = modes.ULP
        compliance["ulp_info"] = {"ulp": opCompliance["ulp"]}
    elif "relative" in opCompliance:
        mode = modes.RELATIVE
        compliance["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": opCompliance["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = modes.REDUCE_PRODUCT
        compliance["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = modes.ABS_ERROR
        if "abs_error_lower_bound" in opCompliance:
            compliance["abs_error_info"] = {
                "lower_bound": opCompliance["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = modes.ABS_ERROR
        if "abs_error_normal_divisor" in opCompliance:
            compliance["abs_error_info"] = {
                "normal_divisor": opCompliance["abs_error_normal_divisor"]
            }
    else:
        mode = modes.EXACT

    compliance["mode"] = gtu.ComplianceMode(mode).name
    return compliance
299
300 # Build Op functions
301 # Create the output tensor (calling OutputShaper as needed)
302 # Do final tweaks to attributes (if necessary for errorIf)
303 # Add Op into graph
304 # Return resulting tensor information or BuildInfo
305
class BuildInfo:
    """Result tensor(s) of a build function plus the matching compliance dict(s)."""

    def __init__(self, resultTensor, complianceDict):
        # A single tensor (and its optional dict) is stored as a 1-item list
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return
        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return {"version": "0.1", "tensors": {name: dict}} or None.

        Only tensors with a non-None compliance dict are included. None is
        returned when there are no compliance dicts at all or all are None.
        """
        if self.complianceDictList is None:
            return None
        perTensor = {
            tensor.name: info
            for tensor, info in zip(self.resultTensorList, self.complianceDictList)
            if info is not None
        }
        if not perTensor:
            return None
        return {"version": "0.1", "tensors": perTensor}
Eric Kunzee5e26762020-10-13 16:11:07 -0700339
def build_unary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a single-input operator (output from OutputShaper.unaryOp) to the graph.

    For NEGATE, ``qinfo[0]`` / ``qinfo[1]`` become the NegateAttribute zero
    points. With the WrongOutputType error, and an output dtype that is not
    INT8/UINT8, ``qinfo`` is replaced by freshly chosen zero points.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    (a,) = inputs
    result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name)

    assert not isinstance(op, int)

    # Make sure the (new) output type has matching zero points
    if (
        error_name == ErrorIf.WrongOutputType
        and result_tensor.dtype not in [DType.INT8, DType.UINT8]
    ):
        zeropointArg = self.args.zeropoint
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, zeropointArg, a.dtype),
            TosaQuantGen.getZeroPoint(rng, zeropointArg, result_tensor.dtype),
        ]

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the name lists may be invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [a.name], [result_tensor.name]
    )

    valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not valid:
        return None

    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700401
def build_binary_broadcast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a two-input operator (output from OutputShaper.binaryBroadcastOp) to the graph.

    ``validator_fcns`` now defaults to None, matching the other build_*
    functions; existing callers that pass it are unaffected.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700441
def build_arithmetic_right_shift(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an arithmetic right shift operator for two inputs to the graph.

    ``args_dict["round"]`` is passed to the ArithmeticRightShiftAttribute;
    the output comes from OutputShaper.binaryBroadcastOp.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    rounding = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, rng, lhs, rhs, error_name
    )

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the name lists may be invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [result_tensor.name]
    )

    checks = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(rounding)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, lhs.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Chengaee1fac2020-11-11 13:54:06 -0800492
def build_mul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a multiply operator: two broadcastable inputs plus a shift tensor.

    For non-FP16/BF16/FP32 inputs the result dtype is forced to INT32. With
    the WrongOutputType error it is instead set to a random choice of
    INT8/INT16/INT48.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects the test.
    """
    # MUL is a binary operator, but it also takes a shift value tensor
    assert len(inputs) == 3
    a, b, shift = inputs

    result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name)
    if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)
    if error_name == ErrorIf.WrongOutputType:
        result_tensor.setDtype(rng.choice([DType.INT8, DType.INT16, DType.INT48]))

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the name lists may be invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [a.name, b.name, shift.name], [result_tensor.name]
    )

    passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not passed:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550
def build_table(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a table lookup operator using ``args_dict["table"]`` to the graph.

    The output comes from OutputShaper.tableOp; the table is stored in a
    TableAttribute.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    (a,) = inputs
    table = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table)

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the name lists may be invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [a.name], [result_tensor.name]
    )

    checks = dict(
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, a.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700600
def build_select(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a select operator with inputs (condition, a, b) to the graph.

    The output comes from OutputShaper.selectOp.

    Returns:
        TosaTestGen.BuildInfo for the result, or None when ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 3
    cond, a, b = inputs
    result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name)

    pCount, cCount = op["operands"]
    # For ERROR_IF tests the name lists may be invalidated
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [cond.name, a.name, b.name], [result_tensor.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=cond,
        input2=a,
        input3=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=pCount + cCount,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700653
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a binary comparison operator with two inputs and one result.

    Args:
        rng: Random generator, passed to OutputShaper.binaryComparisonOp and
            to TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator dict. ``op["op"]`` is the operator added to the
            serializer. ``op["operands"]`` is a pair of counts whose sum is
            reported to the ERROR_IF validators as ``num_operands``.
        inputs: Sequence of exactly two input tensors.
        args_dict: Test arguments, forwarded to tensorComplianceMetaData.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo holding the result tensor and its compliance
        metadata, or None when TosaErrorValidator.evValidateErrorIfs rejects
        this combination.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    result = OutputShaper.binaryComparisonOp(self.ser, rng, lhs, rhs, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [lhs.name, rhs.name], [result.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=lhs,
        input2=rhs,
        input_shape=lhs.shape,
        input_dtype=lhs.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not accepted:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700706
def build_argmax(
    self, rng, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Serialize an ARGMAX operator over the axis given in ``args_dict["axis"]``.

    Args:
        rng: Random generator, passed to OutputShaper.argmaxOp and to
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator dict. ``op["op"]`` is serialized, and the sum of the
            ``op["operands"]`` pair is reported as ``num_operands``.
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide "axis". Also forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    result = OutputShaper.argmaxOp(self.ser, rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not accepted:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700750
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a 2D pooling operator.

    Args:
        rng: Random generator used by the output shaper, zero-point
            regeneration and ERROR_IF list invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and the sum of the
            ``op["operands"]`` pair is reported as ``num_operands``.
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide "stride", "pad" and "kernel". "acc_type" is
            optional (max_pool has none) and defaults to DType.UNKNOWN.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Zero-point pair, or None. For ErrorIf.WrongInputType with a
            non-INT8/UINT8 input it is regenerated from the input and result
            dtypes. If it is still None after validation, [0, 0] is used.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    result = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo. The zero points are drawn
    # in order: input dtype first, then result dtype.
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, dtype)
            for dtype in (ifm.dtype, result.dtype)
        ]

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [ifm.name], [result.name]
    )

    accepted = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        output_shape=result.shape,
        output_dtype=result.dtype,
        accum_dtype=accum_dtype,
        kernel=kernel,
        stride=stride,
        pad=pad,
        qinfo=qinfo,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not accepted:
        return None

    zero_points = [0, 0] if qinfo is None else qinfo

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(
        kernel, stride, pad, zero_points[0], zero_points[1], accum_dtype
    )
    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
James Ward8b390432022-08-12 20:48:56 +0100828
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV2D operator.

    Args:
        rng: Random generator used by the output shaper, zero-point
            regeneration and ERROR_IF list invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and
            ``sum(op["operands"])`` is reported as ``num_operands``.
        inputs: Exactly three tensors: (ifm, filter, bias).
        args_dict: Must provide "acc_type", "stride", "pad" (4 values) and
            "dilation". Also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute, or None.
            It is regenerated for ErrorIf.WrongInputType when the input is
            not INT8/UINT8. If it is still None after validation, [0, 0] is
            used.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and qinfo[0] below would then raise TypeError.
    # Fall back to zero zero-points, as build_pool2d does. This is done after
    # validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700916
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONV3D operator.

    Args:
        rng: Random generator used by the output shaper, zero-point
            regeneration and ERROR_IF list invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and
            ``sum(op["operands"])`` is reported as ``num_operands``.
        inputs: Exactly three tensors: (ifm, filter, bias).
        args_dict: Must provide "acc_type", "stride", "pad" (6 values) and
            "dilation". Also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute, or None.
            It is regenerated for ErrorIf.WrongInputType when the input is
            not INT8/UINT8. If it is still None after validation, [0, 0] is
            used.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and qinfo[0] below would then raise TypeError.
    # Fall back to zero zero-points, as build_pool2d does. This is done after
    # validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001004
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a TRANSPOSE_CONV2D operator.

    Args:
        rng: Random generator used by the output shaper, zero-point
            regeneration and ERROR_IF list invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and
            ``sum(op["operands"])`` is reported as ``num_operands``.
        inputs: Exactly three tensors: (ifm, filter, bias).
        args_dict: Must provide "acc_type", "stride", "pad" (4 values) and
            "out_shape". Also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to TransposeConvAttribute,
            or None. It is regenerated for ErrorIf.WrongInputType when the
            input is not INT8/UINT8. If it is still None after validation,
            [0, 0] is used.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and qinfo[0] below would then raise TypeError.
    # Fall back to zero zero-points, as build_pool2d does. This is done after
    # validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001083
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a DEPTHWISE_CONV2D operator.

    Args:
        rng: Random generator used by the output shaper, zero-point
            regeneration and ERROR_IF list invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and
            ``sum(op["operands"])`` is reported as ``num_operands``.
        inputs: Exactly three tensors: (ifm, filter, bias).
        args_dict: Must provide "acc_type", "stride", "pad" and "dilation".
            Also forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Two-element zero-point list passed to ConvAttribute, or None.
            It is regenerated for ErrorIf.WrongInputType when the input is
            not INT8/UINT8. If it is still None after validation, [0, 0] is
            used.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        weight,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(
                rng, self.args.zeropoint, result_tensor.dtype
            ),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=weight.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and qinfo[0] below would then raise TypeError.
    # Fall back to zero zero-points, as build_pool2d does. This is done after
    # validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001170
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a FULLY_CONNECTED operator.

    Args:
        rng: Random generator used by the output shaper and ERROR_IF list
            invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and the sum of the
            ``op["operands"]`` pair is reported as ``num_operands``.
        inputs: Exactly three tensors: (ifm, filter, bias).
        args_dict: Must provide "acc_type". Also forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Two-element list passed to FullyConnectedAttribute (presumably
            input/weight zero points - confirm against the serializer), or
            None. If None, [0, 0] is used after validation.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 3
    ifm, weight, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, weight, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, weight.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=weight.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and qinfo[0] below would then raise TypeError.
    # Fall back to zero zero-points, as build_pool2d does. This is done after
    # validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001227
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a MATMUL operator.

    Args:
        rng: Random generator used by the output shaper and ERROR_IF list
            invalidation.
        op: Operator dict. ``op["op"]`` is serialized, and the sum of the
            ``op["operands"]`` pair is reported as ``num_operands``.
        inputs: Exactly two input tensors (a, b).
        args_dict: Must provide "acc_type". Also forwarded to the compliance
            metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Two-element list passed to MatMulAttribute (presumably the
            zero points of a and b - confirm against the serializer), or
            None. If None, [0, 0] is used after validation.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, and qinfo[0] below would then raise TypeError.
    # Fall back to zero zero-points, as build_pool2d does. This is done after
    # validation so the validators still see the caller's value.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001284
def build_reduce(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Serialize a reduction operator over the axis in ``args_dict["axis"]``.

    Side effect: for a valid test (``error_name is None``) of
    Op.REDUCE_PRODUCT, this stores the size of the reduced dimension in
    ``args_dict["n"]``. The caller's dict is mutated before
    tensorComplianceMetaData reads it.

    Args:
        rng: Random generator, passed to OutputShaper.reduceOp and to
            TosaErrorIfArgGen.eiInvalidateInputOutputList.
        op: Operator dict. ``op["op"]`` is serialized, and the sum of the
            ``op["operands"]`` pair is reported as ``num_operands``.
        inputs: Sequence containing exactly one input tensor.
        args_dict: Must provide "axis". Forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF case under test, or None for a valid test.
        qinfo: Not used by this builder.

    Returns:
        TosaTestGen.BuildInfo(result tensor, compliance metadata), or None
        when the ERROR_IF validation rejects this combination.
    """
    assert len(inputs) == 1
    src = inputs[0]
    axis = args_dict["axis"]
    result = OutputShaper.reduceOp(self.ser, rng, src, axis, error_name)

    p_count, c_count = op["operands"]
    operand_total = p_count + c_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [result.name]
    )

    validation_kwargs = dict(
        op=op,
        axis=axis,
        input_shape=src.shape,
        output_shape=result.shape,
        input_dtype=src.dtype,
        output_dtype=result.dtype,
        result_tensors=[result],
        input_list=input_list,
        output_list=output_list,
        num_operands=operand_total,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_kwargs
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Number of products - needed for compliance
        args_dict["n"] = src.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001333
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CLAMP operator with randomly drawn min/max bounds.

    Two random values of the input dtype are drawn from ``rng``. Normally the
    smaller one becomes the minimum and the larger one the maximum. For the
    ErrorIf.MaxSmallerMin case the values are redrawn until they differ, and
    are then deliberately swapped so that max < min.

    Args:
        rng: Random generator used for the output shaper, the bounds and the
            ERROR_IF input/output list invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    def draw_bounds():
        # Two random values of the input's dtype, drawn in order
        return [rng.randNumberDType(src.dtype), rng.randNumberDType(src.dtype)]

    bounds = draw_bounds()
    if error_name == ErrorIf.MaxSmallerMin:
        # Equal values cannot give max < min, so redraw until they differ
        while bounds[0] == bounds[1]:
            bounds = draw_bounds()
        min_val, max_val = max(bounds), min(bounds)
    else:
        min_val, max_val = min(bounds), max(bounds)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        max_val=max_val,
        min_val=min_val,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    clamp_attr = ts.TosaSerializerAttribute()
    # Bounds are serialized as little-endian 4-byte values: float32 for the
    # floating-point inputs, int32 for INT8/INT16.
    if src.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if src.dtype == DType.FP16:
            # Non-tensor fp16 values are given to the reference model as fp32
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        fmt = "<f"
    elif src.dtype in (DType.INT8, DType.INT16):
        fmt = "<i"
    else:
        # Incorrect input types: serialize zero bounds to avoid an internal error
        fmt = "<i"
        min_val = max_val = 0
    min_bytes = struct.pack(fmt, min_val)
    max_bytes = struct.pack(fmt, max_val)
    clamp_attr.ClampAttribute(self.ser.builder, min_bytes, max_bytes)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001413
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a single-input activation operator that carries no attribute.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]

    out_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001461
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONCAT or CONCAT_SHAPE operator over all ``inputs``.

    CONCAT_SHAPE always concatenates along axis 0 and is serialized without
    an attribute; CONCAT reads its axis from args_dict["axis"] and carries it
    in an axis attribute.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Tensors to concatenate; the first one supplies the shape and
            dtype reported to the validators and to the compliance metadata.
        args_dict: Test arguments; "axis" is read for CONCAT only.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    axis = 0 if op["op"] == Op.CONCAT_SHAPE else args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert type(axis) is int

    out_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    in_names = [tensor.name for tensor in inputs]
    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, [out_tensor.name]
    )

    first = inputs[0]
    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        input_shape=first.shape,
        output_shape=out_tensor.shape,
        input_dtype=first.dtype,
        output_dtype=out_tensor.dtype,
        inputs=inputs,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    if op["op"] == Op.CONCAT:
        concat_attr = ts.TosaSerializerAttribute()
        concat_attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        concat_attr = None
    self.ser.addOperator(op["op"], in_names, out_names, concat_attr)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001527
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a PAD operator with its pad-constant attribute.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Two tensors: the data input and the second PAD operand.
        args_dict: Test arguments; "pad", "pad_const_int" and "pad_const_fp"
            are read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Forwarded to the ERROR_IF validators.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    src, pad_operand = inputs[0], inputs[1]
    padding = args_dict["pad"]
    const_int = args_dict["pad_const_int"]
    const_fp = args_dict["pad_const_fp"]

    out_tensor = OutputShaper.padOp(self.ser, rng, src, padding, error_name)

    # The pad constant is stored as little-endian bytes: a 4-byte float for
    # floating-point inputs, otherwise a 4-byte signed integer.
    if gtu.dtypeIsFloat(src.dtype):
        fmt, pad_value = "<f", const_fp
    else:
        fmt, pad_value = "<i", const_int
    const_bytes = struct.pack(fmt, pad_value)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, const_bytes)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, pad_operand.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        pad=padding,
        qinfo=qinfo,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
        input1=src,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001591
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a DIM operator querying one axis of the input.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, None) (no compliance data is
        produced), or None when TosaErrorValidator.evValidateErrorIfs rejects
        the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    dim_axis = args_dict["axis"]
    out_tensor = OutputShaper.dimOp(self.ser, rng, src, dim_axis, error_name)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=dim_axis,
        input_shape=src.shape,
        input_dtype=src.dtype,
        output_shape=out_tensor.shape,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(dim_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001638
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESHAPE operator whose new shape is given as a second input.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Two tensors: the data input and the shape operand.
        args_dict: Test arguments; "new_shape" drives the output shape.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    src, shape_operand = inputs[0], inputs[1]
    new_shape = args_dict["new_shape"]
    out_tensor = OutputShaper.reshapeOp(self.ser, rng, src, new_shape, error_name)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, shape_operand.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001687
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a REVERSE operator along a single axis.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "axis" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, None) (no compliance data is
        produced), or None when TosaErrorValidator.evValidateErrorIfs rejects
        the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    rev_axis = args_dict["axis"]
    out_tensor = OutputShaper.unaryOp(self.ser, rng, src, error_name)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=rev_axis,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(rev_axis)
    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)

    return TosaTestGen.BuildInfo(out_tensor, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001734
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TRANSPOSE operator with a permutation attribute.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Exactly one input tensor.
        args_dict: Test arguments; "perms" is read.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    src = inputs[0]
    permutation = args_dict["perms"]

    out_tensor = OutputShaper.transposeOp(self.ser, rng, src, permutation, error_name)

    perm_attr = ts.TosaSerializerAttribute()
    perm_attr.TransposeAttribute(permutation)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        perms=permutation,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
        input1=src,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names, perm_attr)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001788
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SLICE operator whose start and size are given as inputs.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Three tensors: the data input, the start operand and the
            size operand.
        args_dict: Test arguments; "start" and "size" hold the constant
            values used for shaping and validation.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    src, start_operand, size_operand = inputs
    start_vals = args_dict["start"]
    size_vals = args_dict["size"]

    out_tensor = OutputShaper.sliceOp(
        self.ser, rng, src, start_vals, size_vals, error_name
    )

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [src.name, start_operand.name, size_operand.name],
        [out_tensor.name],
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        start=start_vals,
        size=size_vals,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
        input1=src,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001843
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a TILE operator whose multiples are given as a second input.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Two tensors: the data input and the multiples operand.
        args_dict: Test arguments; "multiples" drives the output shape.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    src, multiples_operand = inputs[0], inputs[1]
    multiples_vals = args_dict["multiples"]
    out_tensor = OutputShaper.tileOp(self.ser, rng, src, multiples_vals, error_name)

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [src.name, multiples_operand.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=src.shape,
        output_shape=out_tensor.shape,
        input_dtype=src.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
        input1=src,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, src.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001895
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a GATHER operator from a values tensor and an indices tensor.

    Args:
        rng: Random generator for the output shaper and ERROR_IF invalidation.
        op: Operator definition; the "op" and "operands" entries are read.
        inputs: Two tensors: values and indices. The values tensor supplies
            the shape and dtype reported to the validators and to the
            compliance metadata.
        args_dict: Test arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions.
        error_name: ERROR_IF case being generated, or None.
        qinfo: Unused by this operator.

    Returns:
        TosaTestGen.BuildInfo(result_tensor, compliance), or None when
        TosaErrorValidator.evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    values_t, indices_t = inputs

    out_tensor = OutputShaper.gatherOp(
        self.ser, rng, values_t, indices_t, error_name
    )

    num_pos, num_const = op["operands"]
    # For ERROR_IF tests the input/output name lists may be invalidated here
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [values_t.name, indices_t.name], [out_tensor.name]
    )

    checks_passed = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_t.shape,
        output_shape=out_tensor.shape,
        input_dtype=values_t.dtype,
        output_dtype=out_tensor.dtype,
        result_tensors=[out_tensor],
        input_list=in_names,
        output_list=out_names,
        num_operands=num_pos + num_const,
    )
    if not checks_passed:
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_t.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001945
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a SCATTER operator test.

    ``inputs`` must hold exactly three tensors: ``values_in``, ``indices``
    and ``input``. The result tensor is created by
    ``OutputShaper.scatterOp``. For ERROR_IF tests the operand name lists
    may be altered by ``TosaErrorIfArgGen.eiInvalidateInputOutputList``
    before validation.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or ``None`` when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 3
    values_in, indices, input_tensor = inputs

    result_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, input_tensor, error_name
    )

    # Operand counts declared by the op definition
    p_count, c_count = op["operands"]

    # Possibly invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [values_in.name, indices.name, input_tensor.name],
        [result_tensor.name],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001994
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESIZE operator test.

    ``inputs`` must hold four tensors: the data input followed by the
    scale, offset and border tensors. The ``mode``, ``scale``, ``offset``,
    ``border`` and ``output_dtype`` values read from ``args_dict`` drive
    ``OutputShaper.resizeOp`` and the ERROR_IF validation. The serialized
    ResizeAttribute itself is written with empty scale/offset/border lists.

    ``validator_fcns`` now defaults to ``None`` for consistency with the
    other ``build_*`` methods. It was previously required; existing
    positional/keyword callers are unaffected.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or ``None`` when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 4
    input_tensor, scale_input, offset_input, border_input = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input_tensor,
        mode,
        scale,
        offset,
        border,
        input_tensor.dtype,
        output_dtype,
        error_name,
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [
        input_tensor.name,
        scale_input.name,
        offset_input.name,
        border_input.name,
    ]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tensor.dtype,
        output_dtype=output_dtype,
        input_shape=input_tensor.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    # write empty scale/offset/border into ResizeAttribute
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, input_tensor.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002074
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CONST test by exposing the single supplied tensor as an output.

    No operator is added. The tensor in ``inputs`` is registered as a graph
    output via ``self.ser.addOutputTensor``. ``rng``, ``validator_fcns`` and
    ``qinfo`` are unused.

    Returns:
        ``TosaTestGen.BuildInfo`` with that tensor and its compliance metadata.
    """
    assert len(inputs) == 1
    (const_tensor,) = inputs

    self.ser.addOutputTensor(const_tensor)

    return TosaTestGen.BuildInfo(
        const_tensor,
        self.tensorComplianceMetaData(
            op, const_tensor.dtype, args_dict, const_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002094
2095 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a CAST operator test.

    Converts the single tensor in ``inputs`` to ``args_dict["out_type"]``.
    The result tensor is created by ``OutputShaper.typeConversionOp``.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or ``None`` when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 1
    (source,) = inputs
    target_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, source, target_dtype, error_name
    )

    # Operand counts declared by the op definition
    p_count, c_count = op["operands"]

    # Possibly invalidate the input/output name lists for ERROR_IF checks
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [source.name], [result_tensor.name]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=source.shape,
        output_shape=result_tensor.shape,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_names,
        output_list=output_names,
        num_operands=p_count + c_count,
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_names, output_names)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, source.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002146
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a RESCALE operator test.

    ``inputs`` must hold three tensors: the data input, the multiplier
    tensor and the shift tensor. ``args_dict`` supplies ``output_dtype``,
    ``scale`` (scale32 flag), ``double_round``, ``per_channel``, ``shift``
    and ``multiplier``.

    Input and output zero points are chosen here from ``rng``, depending on
    the dtypes and on ``error_name``. The passed-in ``qinfo`` is ignored and
    replaced by ``(input_zp, output_zp)`` for validation.

    When scale32 is set and this is not an ERROR_IF test, the input data file
    is reloaded. Each value minus ``input_zp`` is clamped to
    ``[-(1 << (shift-1)), (1 << (shift-1)) - 1]`` per channel, and the file
    is rewritten if anything changed.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or ``None`` when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 3
    val, multiplier_val, shift_val = inputs
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # Number of multiplier/shift channels
    nc = val.shape[-1] if per_channel else 1

    input_unsigned = False
    output_unsigned = False

    if val.dtype == DType.INT8:
        input_zp = rng.randInt(-128, 128)
    elif val.dtype == DType.UINT8:
        input_zp = rng.randInt(0, 256)
        input_unsigned = True
    elif error_name in [
        ErrorIf.InputZeroPointNotZero,
        ErrorIf.U16InputZeroPointNotValid,
    ]:
        input_zp = rng.randInt(-128, 128)
        if input_zp == 0:
            input_zp = input_zp + rng.integers(1, 10)
    elif val.dtype == DType.UINT16:
        # Must come after ErrorIf.U16InputZeroPointNotValid check
        input_zp = rng.choice([0, 32768])
        input_unsigned = True
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = rng.randInt(-128, 128)
    elif out_dtype == DType.UINT8:
        output_zp = rng.randInt(0, 256)
        output_unsigned = True
    elif error_name in [
        ErrorIf.OutputZeroPointNotZero,
        ErrorIf.U16OutputZeroPointNotValid,
    ]:
        output_zp = rng.randInt(-128, 128)
        if output_zp == 0:
            output_zp = output_zp + rng.integers(1, 10)
    elif out_dtype == DType.UINT16:
        # Must come after ErrorIf.U16OutputZeroPointNotValid check
        output_zp = rng.choice([0, 32768])
        output_unsigned = True
    else:
        output_zp = 0

    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        # Use a Python int so the shifts cannot overflow a fixed-width
        # numpy integer if shift_arr happens to hold numpy scalars
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    logger.debug(
        f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        data_path = os.path.join(self.basePath, self.testPath, val.placeholderFilename)
        values = np.load(data_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)
        # Check we can safely convert back to the stored dtype. This must
        # compare the extreme values; an all() reduction collapses to a
        # bool and can never fail the range check.
        dtype_info = np.iinfo(values.dtype)
        assert val_adj.size == 0 or (
            val_adj.min() >= dtype_info.min and val_adj.max() <= dtype_info.max
        ), f"build_rescale: adjusted values out of range for {values.dtype}"

        # Cast back to the dtype of the stored input data
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(data_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    # Zero points are validated through qinfo
    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002313
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Add and return a constant tensor holding the COND_IF condition ``cond``.

    By default the tensor is BOOL with shape ``[]`` (rank 0):

    * For ``ErrorIf.CondIfCondNotMatchingBool`` the dtype comes from
      ``gtu.get_wrong_output_type`` instead.
    * For ``ErrorIf.CondIfCondShapeNotSizeOne`` the shape is randomly
      ``[2]`` or ``[1, 2]``.
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    else:
        cond_shape = [2] if rng.choice([1, 2]) == 1 else [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2330
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks each output an INT32 constant.

    ``inputs`` holds two tensors. Only the shape of the first one is used;
    it fixes the shape of the result and of both block constants. The
    condition value is ``args_dict["condition"]``.

    For the CondIfOutputList{Then,Else}GraphMismatch ERROR_IF tests, the
    named block's constant is given a randomly perturbed shape.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or ``None`` when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 2
    then_tens, _ = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = then_tens.shape
    dtype = DType.INT32

    then_mismatch = error_name == ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = error_name == ErrorIf.CondIfOutputListElseGraphMismatch

    if then_mismatch or else_mismatch:
        # Perturb every dimension so the block output cannot match
        incorrect_shape = deepcopy(then_tens.shape)
        for idx, dim in enumerate(incorrect_shape):
            delta = rng.choice([-3, -2, 2, 3]) if dim > 3 else rng.choice([1, 2, 4])
            incorrect_shape[idx] = dim + delta
        incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape))

    then_arr = np.int32(rng.integers(0, 256, size=out_shape))
    else_arr = np.int32(rng.integers(0, 256, size=out_shape))

    # The result tensor takes the shape shared by the expected block outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Each block is just a single constant registered as its output
    for block_name, block_arr, use_incorrect in (
        (then_block, then_arr, then_mismatch),
        (else_block, else_arr, else_mismatch),
    ):
        self.ser.addBasicBlock(block_name)
        if use_incorrect:
            block_out = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_out = self.ser.addConst(out_shape, dtype, block_arr)
        self.ser.addOutputTensor(block_out)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002416
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a COND_IF test whose THEN/ELSE blocks apply a binary op to ``a, b``.

    The blocks use add/sub for FP32, BF16, FP16 and INT32 inputs, and
    logical_right_shift/logical_left_shift for INT8 and INT16 inputs. The
    condition value is ``args_dict["condition"]``. Compliance metadata is
    built against whichever block op the condition selects.

    For the CondIf{Input,Output}List{Then,Else}GraphMismatch ERROR_IF tests,
    the named block gets an input or output whose shape has been randomly
    perturbed.

    Returns:
        ``TosaTestGen.BuildInfo`` with the result tensor and its compliance
        metadata, or ``None`` when ``evValidateErrorIfs`` rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Per-block flags for which list (if any) should be mismatched
    input_mismatch = {
        then_block: error_name == ErrorIf.CondIfInputListThenGraphMismatch,
        else_block: error_name == ErrorIf.CondIfInputListElseGraphMismatch,
    }
    output_mismatch = {
        then_block: error_name == ErrorIf.CondIfOutputListThenGraphMismatch,
        else_block: error_name == ErrorIf.CondIfOutputListElseGraphMismatch,
    }

    if any(input_mismatch.values()) or any(output_mismatch.values()):
        # A copy of ``a`` with every dimension perturbed
        incorrect_shape = a.shape.copy()
        for idx, dim in enumerate(incorrect_shape):
            incorrect_shape[idx] = dim + rng.choice([-3, -2, 2, 3])
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_key, else_key = "add", "sub"
    elif a.dtype in (DType.INT8, DType.INT16):
        then_key, else_key = "logical_right_shift", "logical_left_shift"
    else:
        assert False, f"No tests for DType: {a.dtype}"
    then_op = self.TOSA_OP_LIST[then_key]
    else_op = self.TOSA_OP_LIST[else_key]

    # Compliance checks the element-wise op of the block that will execute
    compliance_op = then_op if cond else else_op

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)
        first_input = incorrect_block_input if input_mismatch[block] else a
        self.ser.addInputTensor(first_input)
        self.ser.addInputTensor(b)
        block_shape = (
            incorrect_block_input.shape if output_mismatch[block] else a.shape
        )
        block_out = self.ser.addOutput(block_shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [block_out.name])

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            compliance_op, a.dtype, args_dict, result_tensor, error_name
        ),
    )
2521
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a WHILE_LOOP test that accumulates ``a`` into an accumulator.

    The loop carries three tensors (iteration count, ``a`` and the
    accumulator). The COND block computes ``iter > 0``; the BODY block
    computes ``acc + a`` and ``iter - 1``. For the WHILE_LOOP ERROR_IF
    cases the shapes/types of the graph inputs/outputs are deliberately
    corrupted.

    Args:
        rng: Random generator used to pick the corruptions.
        op: Operator description from TOSA_OP_LIST.
        inputs: Single-element list holding the input tensor ``a``.
        args_dict: Build arguments; ``"iterations"`` sets the loop count.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF error being tested, or ``None``.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` for the accumulator output, or ``None``
        when the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iteration_count = args_dict["iterations"]

    iter_tens = self.ser.addPlaceholder(
        [], DType.INT32, [np.int32(iteration_count)]
    )

    cond_block, body_block = "COND_BLOCK", "BODY_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialized to zeros
    acc = self.ser.addPlaceholder(a.shape, a.dtype, np.int32(np.zeros(a.shape)))

    shape_deltas = [-3, -2, 2, 3]

    def perturbed_copy(tens):
        # Deep copy of tens with every dimension shifted by a random delta
        clone = deepcopy(tens)
        for dim in range(len(clone.shape)):
            clone.shape[dim] += rng.choice(shape_deltas)
        return clone

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = perturbed_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    if error_name in (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    ):
        incorrect_iter = perturbed_copy(iter_tens)
        if len(incorrect_iter.shape) == 0:
            # A rank-0 tensor has no dimension to perturb, so add one
            incorrect_iter.shape.append(rng.choice(shape_deltas))
        incorrect_acc = perturbed_copy(acc)

    # COND block (inputs: iter, a, acc; output: cond_tens = iter > 0)
    self.ser.addBasicBlock(cond_block)
    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
        cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32])
    else:
        cond_type = DType.BOOL
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (inputs: iter, a, acc; outputs: iter - 1, a, acc + a)
    # Local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)
    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (incorrect_iter, a, incorrect_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_like, acc_like = incorrect_iter, incorrect_acc
    else:
        iter_like, acc_like = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_like.shape, iter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )
    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002659
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an FFT2D operator from its two input tensors.

    Args:
        rng: Random generator for output shaping / ERROR_IF corruption.
        op: Operator description from TOSA_OP_LIST.
        inputs: The two input tensors ``(val1, val2)``.
        args_dict: Build arguments; ``"inverse"`` selects the inverse FFT.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF error being tested, or ``None``.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensors and one
        compliance entry per result, or ``None`` when the ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 2
    val1, val2 = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, val1, val2, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    # Possibly corrupt the input/output name lists for ERROR_IF tests
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [val1.name, val2.name],
        [tens.name for tens in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=val1,
        input2=val2,
        input_shape=val1.shape,
        input_dtype=val1.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.FFTAttribute(inverse, False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val1.dtype, args_dict, tens, error_name)
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002724
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build an RFFT2D operator from its single input tensor.

    Args:
        rng: Random generator for output shaping / ERROR_IF corruption.
        op: Operator description from TOSA_OP_LIST.
        inputs: Single-element list holding the input tensor.
        args_dict: Build arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF error being tested, or ``None``.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` holding the result tensors and one
        compliance entry per result, or ``None`` when the ERROR_IF
        validation rejects the test.
    """
    assert len(inputs) == 1
    (val,) = inputs
    results = OutputShaper.rfft2dOp(self.ser, rng, val, error_name)

    placeholder_count, const_count = op["operands"]
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    # Possibly corrupt the input/output name lists for ERROR_IF tests
    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [val.name], [tens.name for tens in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=placeholder_count + const_count,
    )
    if not is_valid:
        return None

    attr = ts.TosaSerializerAttribute()
    # TODO - Test local_bound, for now set local bound attribute to False
    attr.RFFTAttribute(False)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val.dtype, args_dict, tens, error_name)
        for tens in results
    ]
    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002782
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Build a two-input shape operator.

    Args:
        rng: Random generator for output shaping / ERROR_IF corruption.
        op: Operator description from TOSA_OP_LIST.
        inputs: The two input tensors ``(a, b)``.
        args_dict: Build arguments, forwarded to the compliance metadata.
        validator_fcns: ERROR_IF validator functions to run.
        error_name: ERROR_IF error being tested, or ``None``.
        qinfo: Unused.

    Returns:
        ``TosaTestGen.BuildInfo`` for the result tensor, or ``None`` when
        the ERROR_IF validation rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    # The random generator is passed here, as in every other build function;
    # passing the test generator object instead breaks the random choices
    # made for the WrongInputList/WrongOutputList errors.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2835
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Combine the requested test filters with what the operator supports.

    Args:
        op: Operator description; its ``"rank"`` (min, max) and ``"types"``
            entries are used.
        shapeFilter: Requested shapes; an empty/``None`` value becomes
            ``[None]``.
        rankFilter: Requested ranks, or ``None``.
        dtypeFilter: Requested dtypes, or ``None``.
        testType: ``"positive"`` or ``"negative"``.
        validator: ERROR_IF validator, required for negative tests; any
            non-``None`` entry of its ``param_reqs`` replaces the matching
            filter.

    Returns:
        A dict with ``shapeFilter``, ``rankFilter`` and ``dtypeFilter``
        entries, or ``None`` for a negative test without a validator or for
        an unrecognised ``testType``.
    """
    # Default testing rank range, 1-4 inclusive, to keep test sizes reasonably small
    default_test_rank_range = range(1, 5)
    shapeFilter = shapeFilter or [None]

    rmin, rmax = op["rank"]
    op_rank_range = range(rmin, rmax + 1)
    if rankFilter is not None:
        # Only keep the requested ranks that the operator allows
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Bound the default by the default range or by the operator,
        # whichever range of ranks is smaller
        if len(op_rank_range) <= len(default_test_rank_range):
            cleanRankFilter = op_rank_range
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = op_rank_range

    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        # Operator dtypes that were requested; list entries match on their first item
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }
    if testType != "negative" or validator is None:
        return None

    # ERROR_IF parameter requirements override the filters they specify
    error_arguments = validator(check=False, op=op)["param_reqs"]

    def required_or(key, fallback):
        required = error_arguments[key]
        return fallback if required is None else required

    return {
        # Reduce number of shapes to keep test numbers small
        "shapeFilter": required_or("shape", shapeFilter[:2]),
        "rankFilter": required_or("rank", cleanRankFilter),
        "dtypeFilter": required_or("dtype", cleanDtypeFilter),
    }
2916
def genOpTestList(
    self,
    opName,
    shapeFilter=None,
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Create the list of tests to generate for a single operator.

    Args:
        opName: Name of the operator, a key of ``self.TOSA_OP_LIST``.
        shapeFilter: Optional list of shapes to test. ``None`` (the default)
            is treated as ``[None]``, i.e. no specific shape is requested.
        rankFilter: Optional list of ranks to restrict the tests to.
        dtypeFilter: Optional list of dtypes to restrict the tests to.
        testType: ``"positive"`` for normal tests or ``"negative"`` for
            ERROR_IF tests.

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples, one per test to build.

    Raises:
        Exception: If ``opName`` is unknown, or if negative tests were
            requested (outside level8k mode) and some ERROR_IF types
            produced no tests.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as err:
        raise Exception("Cannot find op with name {}".format(opName)) from err

    # Avoid a shared mutable default argument; [None] means no shape filter
    if shapeFilter is None:
        shapeFilter = [None]

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    # Test list consists of tuples of:
    # (opName, testNameStr, dtype, error_name, shapeList, argumentsList)
    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
        num_error_types_created = 0
    else:
        error_if_validators = [None]
        # None disables the "all ERROR_IF types created" check at the end
        num_error_types_created = None

    for validator in error_if_validators:
        if validator is not None:
            error_name = validator(check=False, op=op)["error_name"]
        else:
            error_name = None

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    if self.args.stable_rng:
                        shape_rng = TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                    else:
                        shape_rng = self.global_rng
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument list consists of (str, []) tuples: the string
                    # representation and the build function argument list
                    if agen_fcn:
                        if self.args.stable_rng:
                            arg_rng = TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                        else:
                            arg_rng = self.global_rng

                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)
                        testStr = "_".join(name_parts)

                        testList.append(
                            (opName, testStr, t, error_name, shapeList, args)
                        )

        if error_name is not None:
            # Check the last test is of the error we wanted
            if len(testList) == 0 or testList[-1][3] != error_name:
                if self.args.level8k:
                    logger.info(f"Missing {error_name} tests due to level8k mode")
                else:
                    logger.error(f"ERROR: Failed to create any {error_name} tests")
                    logger.debug(
                        "Last test created: {}".format(
                            testList[-1] if testList else None
                        )
                    )
            else:
                # Successfully created at least one ERROR_IF test
                num_error_types_created += 1

    if testType == "positive":
        # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
        if "invalid_test_validators" in op:
            invalid_test_validators = op["invalid_test_validators"]
            testList = [
                test
                for test in testList
                if not any(
                    validator_fcn(
                        opName=test[0],
                        input_dtype=test[2],
                        shapeList=test[4],
                        args=test[5],
                    )
                    for validator_fcn in invalid_test_validators
                )
            ]
    else:
        if num_error_types_created is not None and not self.args.level8k:
            remaining_error_types = (
                len(error_if_validators) - num_error_types_created
            )
            if remaining_error_types:
                raise Exception(
                    f"Failed to create {remaining_error_types} error types for {opName}"
                )

    return testList
3068
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build one test for ``opName`` and serialize it.

    Args:
        opName: Name of the operator, a key of ``self.TOSA_OP_LIST``.
        testStr: Test name, used for the serializer and for logging.
        dtype_or_dtypeList: A single dtype applied to every operand, or an
            explicit per-operand list of dtypes.
        error_name: ERROR_IF error being tested, or ``None``.
        shapeList: Per-operand shapes.
        argsDict: Build arguments; must be a dict containing ``"dg_type"``.

    Returns:
        ``True`` when the build function produced a result and the test was
        serialized, ``False`` when it returned a falsy result.

    Raises:
        Exception: If ``opName`` is unknown.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # CONCAT-style ops take one input per shape rather than a fixed count
    variadic = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        operand_count = len(shapeList) if variadic else num_operands
        dtypeList = [dtype_or_dtypeList] * operand_count

    if not variadic:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Set the random number generator
    if self.args.stable_rng:
        build_rng = TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
    else:
        build_rng = self.global_rng

    qinfo = (
        None
        if qgen is None
        else qgen(build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name)
    )

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"
    assert "dg_type" in argsDict

    # Build the random tensor operands and the test
    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    # Extra meta data for the desc.json
    tensMeta = {}
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return False

    # The test is valid; attach any compliance meta data and serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
    return True
Matthew Haddon1c00b712021-10-01 15:51:03 +01003173
def createDynamicOpLists(self):
    """Expand the convolution ``*_TEMPLATE`` ops into one op per kernel size.

    For every kernel size, a shallow copy of the matching template entry is
    added to ``self.TOSA_OP_LIST`` under a name such as ``conv2d_3x3``, with
    its ``filter``, ``template`` and ``real_name`` fields set. Afterwards
    every entry whose ``template`` field is truthy is removed. Does nothing
    when ``conv2d_TEMPLATE`` is absent.
    """
    op_list = self.TOSA_OP_LIST
    if "conv2d_TEMPLATE" not in op_list:
        # Already created these lists (can occur when class is initialized more than once)
        return

    # Kernel sizes to test; level8k mode uses the level's maximum kernel size
    if self.args.level8k:
        big_k = self.TOSA_8K_LEVEL_MAX_KERNEL
        kernels_2d = [[1, big_k], [big_k, 2]]
        kernels_3d = [[1, big_k, 1], [2, 2, big_k]]
    else:
        kernels_2d = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
        kernels_3d = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]

    def add_variant(real_name, kernel):
        # Shallow copy of the template; only top-level keys are overridden
        entry = op_list[real_name + "_TEMPLATE"].copy()
        entry["filter"] = kernel
        entry["template"] = False
        entry["real_name"] = real_name
        suffix = "x".join(str(dim) for dim in kernel)
        op_list["{}_{}".format(real_name, suffix)] = entry

    for kernel in kernels_2d:
        for real_name in ("conv2d", "depthwise_conv2d", "transpose_conv2d"):
            add_variant(real_name, kernel)
    for kernel in kernels_3d:
        add_variant("conv3d", kernel)

    # Delete the templates in a second pass, since removing keys from a
    # dictionary while iterating over it is not allowed
    templates = [name for name, entry in op_list.items() if entry.get("template")]
    for name in templates:
        del op_list[name]
3233
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.

    Also lints TOSA_OP_LIST for missing required fields: ``operands`` must
    unpack to 2 values, ``build_fcn`` to 4 values, and ``types`` and ``op``
    must be present. A missing ``rank`` is set to ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: Naming the first op found with a missing/invalid field.
    """
    required_keys = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )
    for name, entry in self.TOSA_OP_LIST.items():
        # Unpacking validates both presence and tuple length
        try:
            _placeholders, _consts = entry["operands"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(name)
            )

        try:
            _build, _tgen, _tvgen, _arggen = entry["build_fcn"]
        except (KeyError, ValueError, TypeError):
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                    name
                )
            )

        for key, message in required_keys:
            if key not in entry:
                raise Exception(message.format(name))

        # Put in default rank range, if missing
        entry.setdefault("rank", self.DEFAULT_RANK_RANGE)
Eric Kunzee5e26762020-10-13 16:11:07 -07003275
3276 # Tensor operator list
3277 # 'op': op name
3278 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08003279 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
3280 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07003281 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
3282 # 'types': array of datatypes to be tested
James Ward24dbc422022-10-19 12:20:31 +01003283 TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003284
Kevin Cheng550ccc52021-03-03 11:21:43 -08003285 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
James Ward8b390432022-08-12 20:48:56 +01003286 TYPE_INT_FP = [
3287 DType.INT8,
3288 DType.INT16,
3289 DType.INT32,
3290 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003291 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003292 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003293 ] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07003294
Kevin Cheng550ccc52021-03-03 11:21:43 -08003295 TYPE_BOOL = [DType.BOOL]
James Ward24dbc422022-10-19 12:20:31 +01003296 TYPE_FI32 = [
3297 DType.FP32,
3298 DType.FP16,
3299 DType.BF16,
3300 DType.INT32,
3301 ] # floating-types and INT32
James Ward8b390432022-08-12 20:48:56 +01003302 TYPE_FIB = [
3303 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01003304 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003305 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01003306 DType.INT8,
3307 DType.INT16,
3308 DType.INT32,
3309 DType.BOOL,
3310 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003311 TYPE_FI16 = [DType.FP32, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07003312
Won Jeon2c34b462024-02-06 18:37:00 +00003313 TYPE_NARROW_INT_FP = [
3314 DType.INT8,
3315 DType.INT16,
3316 DType.FP16,
3317 DType.BF16,
3318 DType.FP32,
3319 ]
Eric Kunzee5e26762020-10-13 16:11:07 -07003320
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003321 # List of [Input Type 1, Input Type 2, Accumulator Type]
Kevin Cheng1533b852021-09-01 12:51:58 -07003322 TYPE_CONV = [
Kevin Chenga9017402021-07-28 17:19:23 -07003323 [DType.INT8, DType.INT4, DType.INT32],
Kevin Cheng989cb052021-04-28 16:29:44 -07003324 [DType.INT8, DType.INT8, DType.INT32],
3325 [DType.INT16, DType.INT8, DType.INT48],
James Ward8b390432022-08-12 20:48:56 +01003326 [DType.FP16, DType.FP16, DType.FP16],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003327 [DType.FP16, DType.FP16, DType.FP32],
James Ward24dbc422022-10-19 12:20:31 +01003328 [DType.BF16, DType.BF16, DType.FP32],
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003329 [DType.FP32, DType.FP32, DType.FP32],
Won Jeon2c34b462024-02-06 18:37:00 +00003330 [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
3331 [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
Kevin Cheng989cb052021-04-28 16:29:44 -07003332 ]
3333
Jeremy Johnsondd975b82024-02-28 17:29:13 +00003334 DEFAULT_RANK_RANGE = (1, gtu.MAX_TENSOR_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07003335
3336 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003337 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003338 "argmax": {
3339 "op": Op.ARGMAX,
3340 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003341 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003342 "build_fcn": (
3343 build_argmax,
3344 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003345 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003346 TosaArgGen.agAxis,
3347 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003348 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003349 "error_if_validators": (
3350 TosaErrorValidator.evAxisSmallerZero,
3351 TosaErrorValidator.evAxisLargerRank,
3352 TosaErrorValidator.evArgmaxOutputRankMismatch,
3353 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3354 TosaErrorValidator.evWrongRank,
3355 TosaErrorValidator.evWrongInputType,
3356 TosaErrorValidator.evWrongOutputType,
3357 TosaErrorValidator.evWrongInputList,
3358 TosaErrorValidator.evWrongOutputList,
3359 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003360 "data_gen": {
3361 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3362 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003363 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003364 "avg_pool2d": {
3365 "op": Op.AVG_POOL2D,
3366 "operands": (1, 0),
3367 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003368 "build_fcn": (
3369 build_pool2d,
3370 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003371 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003372 TosaArgGen.agPooling,
3373 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003374 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003375 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003376 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003377 "error_if_validators": (
3378 TosaErrorValidator.evKernelSmallerOne,
3379 TosaErrorValidator.evStrideSmallerOne,
3380 TosaErrorValidator.evPadSmallerZero,
3381 TosaErrorValidator.evWrongRank,
3382 TosaErrorValidator.evWrongInputType,
3383 TosaErrorValidator.evWrongOutputType,
3384 TosaErrorValidator.evWrongInputList,
3385 TosaErrorValidator.evWrongOutputList,
3386 TosaErrorValidator.evInputZeroPointNotZero,
3387 TosaErrorValidator.evOutputZeroPointNotZero,
3388 TosaErrorValidator.evPadLargerEqualKernel,
3389 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003390 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003391 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003392 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003393 "data_gen": {
3394 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3395 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003396 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003397 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003398 "conv2d_TEMPLATE": {
3399 "op": Op.CONV2D,
3400 "operands": (1, 2),
3401 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003402 "build_fcn": (
3403 build_conv2d,
3404 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003405 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003406 TosaArgGen.agConv,
3407 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003408 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003409 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003410 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3411 "error_if_validators": (
3412 TosaErrorValidator.evWrongInputType,
3413 TosaErrorValidator.evWrongOutputType,
3414 TosaErrorValidator.evWrongInputList,
3415 TosaErrorValidator.evWrongOutputList,
3416 TosaErrorValidator.evInputZeroPointNotZero,
3417 TosaErrorValidator.evWeightZeroPointNotZero,
3418 TosaErrorValidator.evPadSmallerZero,
3419 TosaErrorValidator.evStrideSmallerOne,
3420 TosaErrorValidator.evDilationSmallerOne,
3421 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003422 TosaErrorValidator.evConvOutputShapeMismatch,
3423 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003424 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003425 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003426 "data_gen": {
3427 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3428 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003429 "template": True,
3430 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003431 # Templated operator. Filled in by createDynamicOpLists
3432 "conv3d_TEMPLATE": {
3433 "op": Op.CONV3D,
3434 "operands": (1, 2),
3435 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003436 "build_fcn": (
3437 build_conv3d,
3438 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003439 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003440 TosaArgGen.agConv,
3441 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003442 "qgen": TosaQuantGen.qgConv,
3443 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003444 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3445 "error_if_validators": (
3446 TosaErrorValidator.evWrongInputType,
3447 TosaErrorValidator.evWrongOutputType,
3448 TosaErrorValidator.evWrongInputList,
3449 TosaErrorValidator.evWrongOutputList,
3450 TosaErrorValidator.evInputZeroPointNotZero,
3451 TosaErrorValidator.evWeightZeroPointNotZero,
3452 TosaErrorValidator.evPadSmallerZero,
3453 TosaErrorValidator.evStrideSmallerOne,
3454 TosaErrorValidator.evDilationSmallerOne,
3455 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003456 TosaErrorValidator.evConvOutputShapeMismatch,
3457 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003458 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003459 ),
evacha0147ab1762024-01-29 13:23:23 +00003460 "data_gen": {
3461 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3462 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003463 "template": True,
3464 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003465 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003466 "depthwise_conv2d_TEMPLATE": {
3467 "op": Op.DEPTHWISE_CONV2D,
3468 "operands": (1, 2),
3469 "filter": [1, 1],
3470 "rank": (4, 4),
3471 "build_fcn": (
3472 build_depthwise_conv2d,
3473 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003474 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003475 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003476 ),
3477 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003478 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003479 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3480 "error_if_validators": (
3481 TosaErrorValidator.evWrongInputType,
3482 TosaErrorValidator.evWrongOutputType,
3483 TosaErrorValidator.evWrongInputList,
3484 TosaErrorValidator.evWrongOutputList,
3485 TosaErrorValidator.evInputZeroPointNotZero,
3486 TosaErrorValidator.evWeightZeroPointNotZero,
3487 TosaErrorValidator.evPadSmallerZero,
3488 TosaErrorValidator.evStrideSmallerOne,
3489 TosaErrorValidator.evDilationSmallerOne,
3490 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003491 TosaErrorValidator.evConvOutputShapeMismatch,
3492 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003493 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003494 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003495 "data_gen": {
3496 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3497 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003498 "template": True,
3499 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003500 "fully_connected": {
3501 "op": Op.FULLY_CONNECTED,
3502 "operands": (1, 2),
3503 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003504 "build_fcn": (
3505 build_fully_connected,
3506 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003507 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003508 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003509 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003510 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003511 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003512 "error_if_validators": (
3513 TosaErrorValidator.evInputZeroPointNotZero,
3514 TosaErrorValidator.evWeightZeroPointNotZero,
3515 TosaErrorValidator.evWrongRank,
3516 TosaErrorValidator.evWrongInputType,
3517 TosaErrorValidator.evWrongOutputType,
3518 TosaErrorValidator.evWrongInputList,
3519 TosaErrorValidator.evWrongOutputList,
3520 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003521 "data_gen": {
3522 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3523 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003524 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003525 "matmul": {
3526 "op": Op.MATMUL,
3527 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003528 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003529 "build_fcn": (
3530 build_matmul,
3531 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003532 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003533 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003534 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003535 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003536 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003537 "error_if_validators": (
3538 TosaErrorValidator.evInputZeroPointNotZero,
3539 TosaErrorValidator.evWrongRank,
3540 TosaErrorValidator.evWrongInputType,
3541 TosaErrorValidator.evWrongOutputType,
3542 TosaErrorValidator.evWrongInputList,
3543 TosaErrorValidator.evWrongOutputList,
3544 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003545 "data_gen": {
3546 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003547 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003548 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003549 "max_pool2d": {
3550 "op": Op.MAX_POOL2D,
3551 "operands": (1, 0),
3552 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003554 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003555 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003556 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003557 TosaArgGen.agPooling,
3558 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003559 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003560 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003561 "error_if_validators": (
3562 TosaErrorValidator.evKernelSmallerOne,
3563 TosaErrorValidator.evStrideSmallerOne,
3564 TosaErrorValidator.evPadSmallerZero,
3565 TosaErrorValidator.evWrongRank,
3566 TosaErrorValidator.evWrongInputType,
3567 TosaErrorValidator.evWrongOutputType,
3568 TosaErrorValidator.evWrongInputList,
3569 TosaErrorValidator.evWrongOutputList,
3570 TosaErrorValidator.evPadLargerEqualKernel,
3571 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003572 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003573 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003574 "data_gen": {
3575 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3576 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003577 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003578 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003579 "transpose_conv2d_TEMPLATE": {
3580 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003581 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003582 "rank": (4, 4),
3583 "build_fcn": (
3584 build_transpose_conv2d,
3585 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003586 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003587 TosaArgGen.agTransposeConv2D,
3588 ),
3589 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003590 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003591 "invalid_test_validators": (
3592 TosaInvalidValidator.ivHeightWidthInvalid,
3593 TosaInvalidValidator.ivNonPositiveOutputShape,
3594 ),
3595 "error_if_validators": (
3596 TosaErrorValidator.evWrongInputType,
3597 TosaErrorValidator.evWrongOutputType,
3598 TosaErrorValidator.evWrongInputList,
3599 TosaErrorValidator.evWrongOutputList,
3600 TosaErrorValidator.evInputZeroPointNotZero,
3601 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003602 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003603 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003604 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003605 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003606 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003607 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003608 "data_gen": {
3609 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3610 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003611 "template": True,
3612 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003613 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003614 "clamp": {
3615 "op": Op.CLAMP,
3616 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003617 "build_fcn": (
3618 build_clamp,
3619 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003620 TosaTensorValuesGen.tvgLazyGenDefault,
3621 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003622 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003623 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003624 "error_if_validators": (
3625 TosaErrorValidator.evMaxSmallerMin,
3626 TosaErrorValidator.evWrongInputType,
3627 TosaErrorValidator.evWrongOutputType,
3628 TosaErrorValidator.evWrongInputList,
3629 TosaErrorValidator.evWrongOutputList,
3630 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003631 "data_gen": {
3632 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3633 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003634 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003635 "sigmoid": {
3636 "op": Op.SIGMOID,
3637 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003638 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003639 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003640 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003641 TosaTensorValuesGen.tvgLazyGenDefault,
3642 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003643 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003644 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003645 "error_if_validators": (
3646 TosaErrorValidator.evWrongInputType,
3647 TosaErrorValidator.evWrongOutputType,
3648 TosaErrorValidator.evWrongInputList,
3649 TosaErrorValidator.evWrongOutputList,
3650 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003651 "data_gen": {
3652 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3653 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003654 },
3655 "tanh": {
3656 "op": Op.TANH,
3657 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003658 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003659 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003661 TosaTensorValuesGen.tvgLazyGenDefault,
3662 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003663 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003664 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003665 "error_if_validators": (
3666 TosaErrorValidator.evWrongInputType,
3667 TosaErrorValidator.evWrongOutputType,
3668 TosaErrorValidator.evWrongInputList,
3669 TosaErrorValidator.evWrongOutputList,
3670 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003671 "data_gen": {
3672 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3673 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003674 "compliance": {
3675 "abs_error_lower_bound": 0.5,
3676 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003677 },
Won Jeon78155c62023-06-10 00:20:04 +00003678 "erf": {
3679 "op": Op.ERF,
3680 "operands": (1, 0),
3681 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003682 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003683 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003684 TosaTensorValuesGen.tvgLazyGenDefault,
3685 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003686 ),
3687 "types": TYPE_FP,
3688 "error_if_validators": (
3689 TosaErrorValidator.evWrongInputType,
3690 TosaErrorValidator.evWrongOutputType,
3691 TosaErrorValidator.evWrongInputList,
3692 TosaErrorValidator.evWrongOutputList,
3693 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003694 "data_gen": {
3695 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3696 },
3697 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003698 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003699 # Elementwise Binary Operators
3700 "add": {
3701 "op": Op.ADD,
3702 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003703 "build_fcn": (
3704 build_binary_broadcast,
3705 TosaTensorGen.tgBroadcastFuzz,
3706 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003707 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003708 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003709 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003710 "error_if_validators": (
3711 TosaErrorValidator.evRankMismatch,
3712 TosaErrorValidator.evWrongInputType,
3713 TosaErrorValidator.evWrongOutputType,
3714 TosaErrorValidator.evWrongInputList,
3715 TosaErrorValidator.evWrongOutputList,
3716 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003717 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003718 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003719 "data_gen": {
3720 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3721 },
3722 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003723 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003724 "arithmetic_right_shift": {
3725 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3726 "operands": (2, 0),
3727 "build_fcn": (
3728 build_arithmetic_right_shift,
3729 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003730 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003731 TosaArgGen.agArithmeticRightShift,
3732 ),
3733 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003734 "error_if_validators": (
3735 TosaErrorValidator.evRankMismatch,
3736 TosaErrorValidator.evWrongInputType,
3737 TosaErrorValidator.evWrongOutputType,
3738 TosaErrorValidator.evWrongInputList,
3739 TosaErrorValidator.evWrongOutputList,
3740 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003741 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003742 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003743 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003744 "bitwise_and": {
3745 "op": Op.BITWISE_AND,
3746 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003747 "build_fcn": (
3748 build_binary_broadcast,
3749 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003750 TosaTensorValuesGen.tvgLazyGenDefault,
3751 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003752 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003753 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003754 "error_if_validators": (
3755 TosaErrorValidator.evRankMismatch,
3756 TosaErrorValidator.evWrongInputType,
3757 TosaErrorValidator.evWrongOutputType,
3758 TosaErrorValidator.evWrongInputList,
3759 TosaErrorValidator.evWrongOutputList,
3760 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003761 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003762 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003763 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003764 "bitwise_or": {
3765 "op": Op.BITWISE_OR,
3766 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003767 "build_fcn": (
3768 build_binary_broadcast,
3769 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003770 TosaTensorValuesGen.tvgLazyGenDefault,
3771 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003772 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003773 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003774 "error_if_validators": (
3775 TosaErrorValidator.evRankMismatch,
3776 TosaErrorValidator.evWrongInputType,
3777 TosaErrorValidator.evWrongOutputType,
3778 TosaErrorValidator.evWrongInputList,
3779 TosaErrorValidator.evWrongOutputList,
3780 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003781 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003782 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003783 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003784 "bitwise_xor": {
3785 "op": Op.BITWISE_XOR,
3786 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003787 "build_fcn": (
3788 build_binary_broadcast,
3789 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003790 TosaTensorValuesGen.tvgLazyGenDefault,
3791 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003792 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003793 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003794 "error_if_validators": (
3795 TosaErrorValidator.evRankMismatch,
3796 TosaErrorValidator.evWrongInputType,
3797 TosaErrorValidator.evWrongOutputType,
3798 TosaErrorValidator.evWrongInputList,
3799 TosaErrorValidator.evWrongOutputList,
3800 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003801 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003802 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003803 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003804 "intdiv": {
3805 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003806 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003807 "build_fcn": (
3808 build_binary_broadcast,
3809 TosaTensorGen.tgBroadcastFuzz,
3810 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003811 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003812 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003813 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003814 "error_if_validators": (
3815 TosaErrorValidator.evRankMismatch,
3816 TosaErrorValidator.evWrongInputType,
3817 TosaErrorValidator.evWrongOutputType,
3818 TosaErrorValidator.evWrongInputList,
3819 TosaErrorValidator.evWrongOutputList,
3820 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003821 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003822 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003823 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003824 "logical_and": {
3825 "op": Op.LOGICAL_AND,
3826 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003827 "build_fcn": (
3828 build_binary_broadcast,
3829 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003830 TosaTensorValuesGen.tvgLazyGenDefault,
3831 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003832 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003833 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003834 "error_if_validators": (
3835 TosaErrorValidator.evRankMismatch,
3836 TosaErrorValidator.evWrongInputType,
3837 TosaErrorValidator.evWrongOutputType,
3838 TosaErrorValidator.evWrongInputList,
3839 TosaErrorValidator.evWrongOutputList,
3840 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003841 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003842 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003843 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003844 "logical_left_shift": {
3845 "op": Op.LOGICAL_LEFT_SHIFT,
3846 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003847 "build_fcn": (
3848 build_binary_broadcast,
3849 TosaTensorGen.tgBroadcastFuzz,
3850 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003851 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003852 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003853 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003854 "error_if_validators": (
3855 TosaErrorValidator.evRankMismatch,
3856 TosaErrorValidator.evWrongInputType,
3857 TosaErrorValidator.evWrongOutputType,
3858 TosaErrorValidator.evWrongInputList,
3859 TosaErrorValidator.evWrongOutputList,
3860 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003861 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003862 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003863 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003864 "logical_right_shift": {
3865 "op": Op.LOGICAL_RIGHT_SHIFT,
3866 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003867 "build_fcn": (
3868 build_binary_broadcast,
3869 TosaTensorGen.tgBroadcastFuzz,
3870 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003871 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003872 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003873 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003874 "error_if_validators": (
3875 TosaErrorValidator.evRankMismatch,
3876 TosaErrorValidator.evWrongInputType,
3877 TosaErrorValidator.evWrongOutputType,
3878 TosaErrorValidator.evWrongInputList,
3879 TosaErrorValidator.evWrongOutputList,
3880 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003881 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003882 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003883 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003884 "logical_or": {
3885 "op": Op.LOGICAL_OR,
3886 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003887 "build_fcn": (
3888 build_binary_broadcast,
3889 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003890 TosaTensorValuesGen.tvgLazyGenDefault,
3891 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003892 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003893 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003894 "error_if_validators": (
3895 TosaErrorValidator.evRankMismatch,
3896 TosaErrorValidator.evWrongInputType,
3897 TosaErrorValidator.evWrongOutputType,
3898 TosaErrorValidator.evWrongInputList,
3899 TosaErrorValidator.evWrongOutputList,
3900 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003901 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003902 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003903 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003904 "logical_xor": {
3905 "op": Op.LOGICAL_XOR,
3906 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003907 "build_fcn": (
3908 build_binary_broadcast,
3909 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003910 TosaTensorValuesGen.tvgLazyGenDefault,
3911 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003912 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003913 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003914 "error_if_validators": (
3915 TosaErrorValidator.evRankMismatch,
3916 TosaErrorValidator.evWrongInputType,
3917 TosaErrorValidator.evWrongOutputType,
3918 TosaErrorValidator.evWrongInputList,
3919 TosaErrorValidator.evWrongOutputList,
3920 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003921 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003922 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003923 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003924 "maximum": {
3925 "op": Op.MAXIMUM,
3926 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003927 "build_fcn": (
3928 build_binary_broadcast,
3929 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003930 TosaTensorValuesGen.tvgLazyGenDefault,
3931 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003932 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003933 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003934 "error_if_validators": (
3935 TosaErrorValidator.evRankMismatch,
3936 TosaErrorValidator.evWrongInputType,
3937 TosaErrorValidator.evWrongOutputType,
3938 TosaErrorValidator.evWrongInputList,
3939 TosaErrorValidator.evWrongOutputList,
3940 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003941 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003942 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003943 "data_gen": {
3944 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3945 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003946 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003947 "minimum": {
3948 "op": Op.MINIMUM,
3949 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003950 "build_fcn": (
3951 build_binary_broadcast,
3952 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003953 TosaTensorValuesGen.tvgLazyGenDefault,
3954 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003955 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003956 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003957 "error_if_validators": (
3958 TosaErrorValidator.evRankMismatch,
3959 TosaErrorValidator.evWrongInputType,
3960 TosaErrorValidator.evWrongOutputType,
3961 TosaErrorValidator.evWrongInputList,
3962 TosaErrorValidator.evWrongOutputList,
3963 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003964 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003965 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003966 "data_gen": {
3967 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3968 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003969 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003970 "mul": {
3971 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003972 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 "build_fcn": (
3974 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003975 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003976 TosaTensorValuesGen.tvgMul,
3977 TosaArgGen.agMul,
3978 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003979 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003980 "error_if_validators": (
3981 TosaErrorValidator.evWrongInputType,
3982 TosaErrorValidator.evWrongOutputType,
3983 TosaErrorValidator.evWrongInputList,
3984 TosaErrorValidator.evWrongOutputList,
3985 TosaErrorValidator.evRankMismatch,
3986 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003987 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003988 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003989 "data_gen": {
3990 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3991 },
3992 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003993 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003994 "pow": {
3995 "op": Op.POW,
3996 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003997 "build_fcn": (
3998 build_binary_broadcast,
3999 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00004000 TosaTensorValuesGen.tvgPow,
4001 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004002 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004003 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004004 "error_if_validators": (
4005 TosaErrorValidator.evRankMismatch,
4006 TosaErrorValidator.evWrongInputType,
4007 TosaErrorValidator.evWrongOutputType,
4008 TosaErrorValidator.evWrongInputList,
4009 TosaErrorValidator.evWrongOutputList,
4010 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004011 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004012 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004013 "data_gen": {
4014 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4015 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004016 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004017 "sub": {
4018 "op": Op.SUB,
4019 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004020 "build_fcn": (
4021 build_binary_broadcast,
4022 TosaTensorGen.tgBroadcastFuzz,
4023 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004024 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004025 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004026 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004027 "error_if_validators": (
4028 TosaErrorValidator.evRankMismatch,
4029 TosaErrorValidator.evWrongInputType,
4030 TosaErrorValidator.evWrongOutputType,
4031 TosaErrorValidator.evWrongInputList,
4032 TosaErrorValidator.evWrongOutputList,
4033 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004034 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004035 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004036 "data_gen": {
4037 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4038 },
4039 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004040 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004041 "table": {
4042 "op": Op.TABLE,
4043 # Use the automatic generation functions to create the input array
4044 # but create the table tensor in the build function, as it may be
4045 # a different type from the input
4046 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004047 "build_fcn": (
4048 build_table,
4049 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004050 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004051 TosaArgGen.agTable,
4052 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004053 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004054 "error_if_validators": (
4055 TosaErrorValidator.evWrongInputType,
4056 TosaErrorValidator.evWrongOutputType,
4057 TosaErrorValidator.evWrongInputList,
4058 TosaErrorValidator.evWrongOutputList,
4059 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004060 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004061 # Elementwise Unary operators
4062 "abs": {
4063 "op": Op.ABS,
4064 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004065 "build_fcn": (
4066 build_unary,
4067 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004068 TosaTensorValuesGen.tvgLazyGenDefault,
4069 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004070 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004071 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004072 "error_if_validators": (
4073 TosaErrorValidator.evWrongInputType,
4074 TosaErrorValidator.evWrongOutputType,
4075 TosaErrorValidator.evWrongInputList,
4076 TosaErrorValidator.evWrongOutputList,
4077 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004078 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004079 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004080 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004081 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004082 "bitwise_not": {
4083 "op": Op.BITWISE_NOT,
4084 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004085 "build_fcn": (
4086 build_unary,
4087 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004088 TosaTensorValuesGen.tvgLazyGenDefault,
4089 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004090 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004091 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004092 "error_if_validators": (
4093 TosaErrorValidator.evWrongInputType,
4094 TosaErrorValidator.evWrongOutputType,
4095 TosaErrorValidator.evWrongInputList,
4096 TosaErrorValidator.evWrongOutputList,
4097 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004098 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004099 "ceil": {
4100 "op": Op.CEIL,
4101 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004102 "build_fcn": (
4103 build_unary,
4104 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004105 TosaTensorValuesGen.tvgLazyGenDefault,
4106 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004107 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004108 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004109 "error_if_validators": (
4110 TosaErrorValidator.evWrongInputType,
4111 TosaErrorValidator.evWrongOutputType,
4112 TosaErrorValidator.evWrongInputList,
4113 TosaErrorValidator.evWrongOutputList,
4114 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004115 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004116 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004117 },
4118 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004119 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004120 "clz": {
4121 "op": Op.CLZ,
4122 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004123 "build_fcn": (
4124 build_unary,
4125 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004126 TosaTensorValuesGen.tvgLazyGenDefault,
4127 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004128 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004129 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004130 "error_if_validators": (
4131 TosaErrorValidator.evWrongInputType,
4132 TosaErrorValidator.evWrongOutputType,
4133 TosaErrorValidator.evWrongInputList,
4134 TosaErrorValidator.evWrongOutputList,
4135 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004136 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004137 "cos": {
4138 "op": Op.COS,
4139 "operands": (1, 0),
4140 "build_fcn": (
4141 build_unary,
4142 TosaTensorGen.tgBasic,
4143 TosaTensorValuesGen.tvgLazyGenDefault,
4144 TosaArgGen.agNone,
4145 ),
4146 "types": TYPE_FP,
4147 "error_if_validators": (
4148 TosaErrorValidator.evWrongInputType,
4149 TosaErrorValidator.evWrongOutputType,
4150 TosaErrorValidator.evWrongInputList,
4151 TosaErrorValidator.evWrongOutputList,
4152 ),
4153 "data_gen": {
4154 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4155 },
4156 "compliance": {"abs_error_normal_divisor": 2},
4157 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004158 "exp": {
4159 "op": Op.EXP,
4160 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004161 "build_fcn": (
4162 build_unary,
4163 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004164 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004165 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004166 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004167 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004168 "error_if_validators": (
4169 TosaErrorValidator.evWrongInputType,
4170 TosaErrorValidator.evWrongOutputType,
4171 TosaErrorValidator.evWrongInputList,
4172 TosaErrorValidator.evWrongOutputList,
4173 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004174 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004175 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004178 "floor": {
4179 "op": Op.FLOOR,
4180 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004181 "build_fcn": (
4182 build_unary,
4183 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004184 TosaTensorValuesGen.tvgLazyGenDefault,
4185 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004186 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004187 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004188 "error_if_validators": (
4189 TosaErrorValidator.evWrongInputType,
4190 TosaErrorValidator.evWrongOutputType,
4191 TosaErrorValidator.evWrongInputList,
4192 TosaErrorValidator.evWrongOutputList,
4193 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004194 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004195 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004196 },
4197 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004198 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004199 "log": {
4200 "op": Op.LOG,
4201 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004202 "build_fcn": (
4203 build_unary,
4204 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004205 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004206 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004207 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004208 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004209 "error_if_validators": (
4210 TosaErrorValidator.evWrongInputType,
4211 TosaErrorValidator.evWrongOutputType,
4212 TosaErrorValidator.evWrongInputList,
4213 TosaErrorValidator.evWrongOutputList,
4214 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004215 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004216 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004217 },
4218 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004219 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004220 "logical_not": {
4221 "op": Op.LOGICAL_NOT,
4222 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004223 "build_fcn": (
4224 build_unary,
4225 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004226 TosaTensorValuesGen.tvgLazyGenDefault,
4227 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004228 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004229 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004230 "error_if_validators": (
4231 TosaErrorValidator.evWrongInputType,
4232 TosaErrorValidator.evWrongOutputType,
4233 TosaErrorValidator.evWrongInputList,
4234 TosaErrorValidator.evWrongOutputList,
4235 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004236 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004237 "negate": {
4238 "op": Op.NEGATE,
4239 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004240 "build_fcn": (
4241 build_unary,
4242 TosaTensorGen.tgBasic,
4243 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004244 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004245 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004246 "qgen": TosaQuantGen.qgUnary,
4247 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004248 "error_if_validators": (
4249 TosaErrorValidator.evInputZeroPointNotZero,
4250 TosaErrorValidator.evOutputZeroPointNotZero,
4251 TosaErrorValidator.evWrongInputType,
4252 TosaErrorValidator.evWrongOutputType,
4253 TosaErrorValidator.evWrongInputList,
4254 TosaErrorValidator.evWrongOutputList,
4255 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004256 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004257 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004258 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004260 "reciprocal": {
4261 "op": Op.RECIPROCAL,
4262 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004263 "build_fcn": (
4264 build_unary,
4265 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004266 TosaTensorValuesGen.tvgLazyGenDefault,
4267 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004268 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004269 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004270 "error_if_validators": (
4271 TosaErrorValidator.evWrongInputType,
4272 TosaErrorValidator.evWrongOutputType,
4273 TosaErrorValidator.evWrongInputList,
4274 TosaErrorValidator.evWrongOutputList,
4275 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004276 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004277 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004278 },
4279 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004281 "rsqrt": {
4282 "op": Op.RSQRT,
4283 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004284 "build_fcn": (
4285 build_unary,
4286 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004287 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004288 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004289 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004290 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004291 "error_if_validators": (
4292 TosaErrorValidator.evWrongInputType,
4293 TosaErrorValidator.evWrongOutputType,
4294 TosaErrorValidator.evWrongInputList,
4295 TosaErrorValidator.evWrongOutputList,
4296 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004297 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004298 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004299 },
4300 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004301 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004302 "sin": {
4303 "op": Op.SIN,
4304 "operands": (1, 0),
4305 "build_fcn": (
4306 build_unary,
4307 TosaTensorGen.tgBasic,
4308 TosaTensorValuesGen.tvgLazyGenDefault,
4309 TosaArgGen.agNone,
4310 ),
4311 "types": TYPE_FP,
4312 "error_if_validators": (
4313 TosaErrorValidator.evWrongInputType,
4314 TosaErrorValidator.evWrongOutputType,
4315 TosaErrorValidator.evWrongInputList,
4316 TosaErrorValidator.evWrongOutputList,
4317 ),
4318 "data_gen": {
4319 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4320 },
4321 "compliance": {"abs_error_normal_divisor": 2},
4322 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004323 # Elementwise Ternary operators
4324 "select": {
4325 "op": Op.SELECT,
4326 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004327 "build_fcn": (
4328 build_select,
4329 TosaTensorGen.tgBroadcastFuzz,
4330 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004331 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004332 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004333 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004334 "error_if_validators": (
4335 TosaErrorValidator.evRankMismatch,
4336 TosaErrorValidator.evWrongInputType,
4337 TosaErrorValidator.evWrongOutputType,
4338 TosaErrorValidator.evWrongInputList,
4339 TosaErrorValidator.evWrongOutputList,
4340 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004341 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004342 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004343 "data_gen": {
4344 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4345 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004346 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004347 # Comparison operators
4348 "equal": {
4349 "op": Op.EQUAL,
4350 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004351 "build_fcn": (
4352 build_comparison,
4353 TosaTensorGen.tgBroadcastFuzz,
4354 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004355 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004356 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004357 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004358 "error_if_validators": (
4359 TosaErrorValidator.evRankMismatch,
4360 TosaErrorValidator.evWrongInputType,
4361 TosaErrorValidator.evWrongOutputType,
4362 TosaErrorValidator.evWrongInputList,
4363 TosaErrorValidator.evWrongOutputList,
4364 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004365 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004366 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004367 "data_gen": {
4368 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4369 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004370 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004371 "greater_equal": {
4372 "op": Op.GREATER_EQUAL,
4373 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004374 "build_fcn": (
4375 build_comparison,
4376 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004377 TosaTensorValuesGen.tvgLazyGenDefault,
4378 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004379 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004380 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004381 "error_if_validators": (
4382 TosaErrorValidator.evRankMismatch,
4383 TosaErrorValidator.evWrongInputType,
4384 TosaErrorValidator.evWrongOutputType,
4385 TosaErrorValidator.evWrongInputList,
4386 TosaErrorValidator.evWrongOutputList,
4387 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004388 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004389 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004390 "data_gen": {
4391 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4392 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004393 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004394 "greater": {
4395 "op": Op.GREATER,
4396 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004397 "build_fcn": (
4398 build_comparison,
4399 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004400 TosaTensorValuesGen.tvgLazyGenDefault,
4401 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004402 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004403 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004404 "error_if_validators": (
4405 TosaErrorValidator.evRankMismatch,
4406 TosaErrorValidator.evWrongInputType,
4407 TosaErrorValidator.evWrongOutputType,
4408 TosaErrorValidator.evWrongInputList,
4409 TosaErrorValidator.evWrongOutputList,
4410 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004411 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004412 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004413 "data_gen": {
4414 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4415 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004416 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004417 # Reduction operators
4418 "reduce_all": {
4419 "op": Op.REDUCE_ALL,
4420 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004421 "build_fcn": (
4422 build_reduce,
4423 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004424 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004425 TosaArgGen.agAxis,
4426 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004427 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004428 "error_if_validators": (
4429 TosaErrorValidator.evAxisLargerRank,
4430 TosaErrorValidator.evAxisSmallerZero,
4431 TosaErrorValidator.evShapeOfAxisNotOne,
4432 TosaErrorValidator.evWrongInputType,
4433 TosaErrorValidator.evWrongOutputType,
4434 TosaErrorValidator.evWrongRank,
4435 TosaErrorValidator.evWrongInputList,
4436 TosaErrorValidator.evWrongOutputList,
4437 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004438 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004439 "reduce_any": {
4440 "op": Op.REDUCE_ANY,
4441 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004442 "build_fcn": (
4443 build_reduce,
4444 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004445 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004446 TosaArgGen.agAxis,
4447 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004448 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004449 "error_if_validators": (
4450 TosaErrorValidator.evAxisLargerRank,
4451 TosaErrorValidator.evAxisSmallerZero,
4452 TosaErrorValidator.evShapeOfAxisNotOne,
4453 TosaErrorValidator.evWrongInputType,
4454 TosaErrorValidator.evWrongOutputType,
4455 TosaErrorValidator.evWrongRank,
4456 TosaErrorValidator.evWrongInputList,
4457 TosaErrorValidator.evWrongOutputList,
4458 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004459 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004460 "reduce_max": {
4461 "op": Op.REDUCE_MAX,
4462 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004463 "build_fcn": (
4464 build_reduce,
4465 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004466 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004467 TosaArgGen.agAxis,
4468 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004469 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004470 "error_if_validators": (
4471 TosaErrorValidator.evAxisLargerRank,
4472 TosaErrorValidator.evAxisSmallerZero,
4473 TosaErrorValidator.evShapeOfAxisNotOne,
4474 TosaErrorValidator.evWrongInputType,
4475 TosaErrorValidator.evWrongOutputType,
4476 TosaErrorValidator.evWrongRank,
4477 TosaErrorValidator.evWrongInputList,
4478 TosaErrorValidator.evWrongOutputList,
4479 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004480 "data_gen": {
4481 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4482 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004483 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004484 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004485 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004486 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004487 "build_fcn": (
4488 build_reduce,
4489 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004490 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004491 TosaArgGen.agAxis,
4492 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004493 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004494 "error_if_validators": (
4495 TosaErrorValidator.evAxisLargerRank,
4496 TosaErrorValidator.evAxisSmallerZero,
4497 TosaErrorValidator.evShapeOfAxisNotOne,
4498 TosaErrorValidator.evWrongInputType,
4499 TosaErrorValidator.evWrongOutputType,
4500 TosaErrorValidator.evWrongRank,
4501 TosaErrorValidator.evWrongInputList,
4502 TosaErrorValidator.evWrongOutputList,
4503 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004504 "data_gen": {
4505 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4506 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004507 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004508 "reduce_product": {
4509 "op": Op.REDUCE_PRODUCT,
4510 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004511 "build_fcn": (
4512 build_reduce,
4513 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004514 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004515 TosaArgGen.agAxis,
4516 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004517 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004518 "error_if_validators": (
4519 TosaErrorValidator.evAxisLargerRank,
4520 TosaErrorValidator.evAxisSmallerZero,
4521 TosaErrorValidator.evShapeOfAxisNotOne,
4522 TosaErrorValidator.evWrongInputType,
4523 TosaErrorValidator.evWrongOutputType,
4524 TosaErrorValidator.evWrongRank,
4525 TosaErrorValidator.evWrongInputList,
4526 TosaErrorValidator.evWrongOutputList,
4527 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004528 "data_gen": {
4529 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4530 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004531 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004532 "reduce_sum": {
4533 "op": Op.REDUCE_SUM,
4534 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004535 "build_fcn": (
4536 build_reduce,
4537 TosaTensorGen.tgBasic,
4538 TosaTensorValuesGen.tvgReduceSum,
4539 TosaArgGen.agAxis,
4540 ),
James Ward24dbc422022-10-19 12:20:31 +01004541 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004542 "error_if_validators": (
4543 TosaErrorValidator.evAxisLargerRank,
4544 TosaErrorValidator.evAxisSmallerZero,
4545 TosaErrorValidator.evShapeOfAxisNotOne,
4546 TosaErrorValidator.evWrongInputType,
4547 TosaErrorValidator.evWrongOutputType,
4548 TosaErrorValidator.evWrongRank,
4549 TosaErrorValidator.evWrongInputList,
4550 TosaErrorValidator.evWrongOutputList,
4551 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004552 "data_gen": {
4553 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4554 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004555 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004556 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004557 "concat": {
4558 "op": Op.CONCAT,
4559 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004560 "build_fcn": (
4561 build_concat,
4562 TosaTensorGen.tgConcat,
4563 TosaTensorValuesGen.tvgConcat,
4564 TosaArgGen.agAxis,
4565 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004566 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004567 "error_if_validators": (
4568 TosaErrorValidator.evAxisLargerRank,
4569 TosaErrorValidator.evAxisSmallerZero,
4570 TosaErrorValidator.evConcatInputRankMismatch,
4571 TosaErrorValidator.evConcatShapeSumMismatch,
4572 TosaErrorValidator.evConcatInputDimMismatch,
4573 TosaErrorValidator.evWrongInputType,
4574 TosaErrorValidator.evWrongOutputType,
4575 TosaErrorValidator.evWrongOutputList,
4576 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004577 "data_gen": {
4578 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4579 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004580 },
4581 "pad": {
4582 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004583 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004584 "build_fcn": (
4585 build_pad,
4586 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004587 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004588 TosaArgGen.agPad,
4589 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004590 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004591 "error_if_validators": (
4592 TosaErrorValidator.evWrongInputType,
4593 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004594 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004595 TosaErrorValidator.evWrongOutputType,
4596 TosaErrorValidator.evWrongInputList,
4597 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004598 TosaErrorValidator.evRankMismatch,
4599 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004600 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004601 "data_gen": {
4602 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4603 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004604 },
Won Jeona21b2e82023-08-10 10:33:01 +00004605 "dim": {
4606 "op": Op.DIM,
4607 "operands": (1, 0),
4608 "build_fcn": (
4609 build_dim,
4610 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004611 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004612 TosaArgGen.agAxis,
4613 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004614 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004615 "error_if_validators": (
4616 TosaErrorValidator.evAxisLargerRank,
4617 TosaErrorValidator.evAxisSmallerZero,
4618 TosaErrorValidator.evWrongInputType,
4619 TosaErrorValidator.evWrongInputList,
4620 TosaErrorValidator.evWrongOutputList,
4621 TosaErrorValidator.evWrongRank,
4622 ),
4623 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004624 "reshape": {
4625 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004626 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004627 "build_fcn": (
4628 build_reshape,
4629 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004630 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004631 TosaArgGen.agReshape,
4632 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004633 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004634 "error_if_validators": (
4635 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4636 TosaErrorValidator.evWrongInputType,
4637 TosaErrorValidator.evWrongOutputType,
4638 TosaErrorValidator.evWrongInputList,
4639 TosaErrorValidator.evWrongOutputList,
4640 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004641 "data_gen": {
4642 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4643 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004644 },
4645 "reverse": {
4646 "op": Op.REVERSE,
4647 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004648 "build_fcn": (
4649 build_reverse,
4650 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004651 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004652 TosaArgGen.agAxis,
4653 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004654 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004655 "error_if_validators": (
4656 TosaErrorValidator.evAxisSmallerZero,
4657 TosaErrorValidator.evAxisLargerRank,
4658 TosaErrorValidator.evWrongInputType,
4659 TosaErrorValidator.evWrongOutputType,
4660 TosaErrorValidator.evWrongInputList,
4661 TosaErrorValidator.evWrongOutputList,
4662 ),
evacha0198477222024-01-26 12:25:32 +00004663 "data_gen": {
4664 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4665 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004666 },
4667 "slice": {
4668 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004669 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004670 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004671 "build_fcn": (
4672 build_slice,
4673 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004674 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004675 TosaArgGen.agSlice,
4676 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004677 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004678 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004679 # TODO Turn off these error categories for now as the reference
4680 # model cannot allocate memory space for empty tensor. We probably
4681 # can report an accurate error message at the right place during
4682 # execution.
4683 # TosaErrorValidator.evStartSmallerZero,
4684 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004685 TosaErrorValidator.evStartSizeOutsideBounds,
4686 TosaErrorValidator.evSizeOutputShapeMismatch,
4687 TosaErrorValidator.evInputSizeStartLengthMismatch,
4688 TosaErrorValidator.evWrongRank,
4689 TosaErrorValidator.evWrongInputType,
4690 TosaErrorValidator.evWrongOutputType,
4691 TosaErrorValidator.evWrongInputList,
4692 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004693 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004694 ),
evacha017f7d4252024-01-24 12:08:09 +00004695 "data_gen": {
4696 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4697 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004698 },
4699 "tile": {
4700 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004701 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004702 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004703 "build_fcn": (
4704 build_tile,
4705 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004706 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004707 TosaArgGen.agTile,
4708 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004709 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004710 "error_if_validators": (
4711 TosaErrorValidator.evWrongInputType,
4712 TosaErrorValidator.evWrongOutputType,
4713 TosaErrorValidator.evWrongInputList,
4714 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004715 TosaErrorValidator.evRankMismatch,
4716 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004717 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004718 "data_gen": {
4719 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4720 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004721 },
4722 "transpose": {
4723 "op": Op.TRANSPOSE,
4724 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004725 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004726 "build_fcn": (
4727 build_transpose,
4728 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004729 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004730 TosaArgGen.agTranspose,
4731 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004732 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004733 "error_if_validators": (
4734 TosaErrorValidator.evIndexOutsideBounds,
4735 TosaErrorValidator.evIndexUsedTwice,
4736 TosaErrorValidator.evWrongInputType,
4737 TosaErrorValidator.evWrongOutputType,
4738 TosaErrorValidator.evWrongInputList,
4739 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004740 TosaErrorValidator.evWrongRank,
4741 TosaErrorValidator.evRankMismatch,
4742 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004743 ),
evacha0198477222024-01-26 12:25:32 +00004744 "data_gen": {
4745 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4746 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004747 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004748 # Data nodes
4749 "const": {
4750 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004751 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004752 "build_fcn": (
4753 build_const,
4754 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004755 TosaTensorValuesGen.tvgLazyGenDefault,
4756 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004757 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004758 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004759 "data_gen": {
4760 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4761 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004762 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004763 "identity": {
4764 "op": Op.IDENTITY,
4765 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004766 "build_fcn": (
4767 build_unary,
4768 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004769 TosaTensorValuesGen.tvgLazyGenDefault,
4770 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004771 ),
evacha011adff832024-03-06 17:33:44 +00004772 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004773 "data_gen": {
4774 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4775 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004776 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004777 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004778 "gather": {
4779 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004780 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004781 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004782 "build_fcn": (
4783 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004784 TosaTensorGen.tgGather,
4785 TosaTensorValuesGen.tvgGather,
4786 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004787 ),
James Ward24dbc422022-10-19 12:20:31 +01004788 "types": (
4789 DType.INT8,
4790 DType.INT16,
4791 DType.INT32,
4792 DType.FP16,
4793 DType.BF16,
4794 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004795 DType.FP8E4M3,
4796 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004797 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004798 "error_if_validators": (
4799 TosaErrorValidator.evWrongInputType,
4800 TosaErrorValidator.evWrongOutputType,
4801 TosaErrorValidator.evWrongInputList,
4802 TosaErrorValidator.evWrongOutputList,
4803 TosaErrorValidator.evWrongRank,
4804 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004805 "data_gen": {
4806 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4807 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004808 },
4809 "scatter": {
4810 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004811 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004812 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004813 "build_fcn": (
4814 build_scatter,
4815 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004816 TosaTensorValuesGen.tvgScatter,
4817 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004818 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004819 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004820 "error_if_validators": (
4821 TosaErrorValidator.evWrongInputType,
4822 TosaErrorValidator.evWrongOutputType,
4823 TosaErrorValidator.evWrongInputList,
4824 TosaErrorValidator.evWrongOutputList,
4825 TosaErrorValidator.evWrongRank,
4826 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004827 "data_gen": {
4828 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4829 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004830 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004831 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004832 "resize": {
4833 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004834 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004835 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004836 "build_fcn": (
4837 build_resize,
4838 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004839 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004840 TosaArgGen.agResize,
4841 ),
James Ward24dbc422022-10-19 12:20:31 +01004842 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004843 "invalid_test_validators": (
4844 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004845 ),
4846 "error_if_validators": (
4847 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004848 TosaErrorValidator.evScaleSmallerEqualZero,
4849 TosaErrorValidator.evScaleNLargerMax,
4850 TosaErrorValidator.evScaleDLargerMax,
4851 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004852 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004853 TosaErrorValidator.evBorderSmallerMin,
4854 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004855 TosaErrorValidator.evWrongInputType,
4856 TosaErrorValidator.evWrongOutputType,
4857 TosaErrorValidator.evWrongRank,
4858 TosaErrorValidator.evWrongInputList,
4859 TosaErrorValidator.evWrongOutputList,
4860 TosaErrorValidator.evBatchMismatch,
4861 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004862 TosaErrorValidator.evResizeOutputShapeMismatch,
4863 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004864 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004865 "data_gen": {
4866 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4867 },
4868 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004869 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004870 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004871 "cast": {
4872 "op": Op.CAST,
4873 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004874 "build_fcn": (
4875 build_cast,
4876 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004877 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004878 TosaArgGen.agCast,
4879 ),
James Ward8b390432022-08-12 20:48:56 +01004880 "types": (
4881 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004882 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004883 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004884 DType.INT8,
4885 DType.INT16,
4886 DType.INT32,
4887 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004888 DType.FP8E4M3,
4889 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004890 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004891 "error_if_validators": (
4892 TosaErrorValidator.evWrongInputType,
4893 TosaErrorValidator.evWrongOutputType,
4894 TosaErrorValidator.evWrongInputList,
4895 TosaErrorValidator.evWrongOutputList,
4896 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004897 "data_gen": {
4898 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4899 },
4900 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004901 },
4902 "rescale": {
4903 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004904 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004905 "build_fcn": (
4906 build_rescale,
4907 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004908 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004909 TosaArgGen.agRescale,
4910 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004911 "types": [
4912 DType.UINT8,
4913 DType.INT8,
4914 DType.INT16,
4915 DType.INT32,
4916 DType.INT48,
4917 DType.UINT16,
4918 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004919 "error_if_validators": (
4920 TosaErrorValidator.evInputZeroPointNotZero,
4921 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004922 TosaErrorValidator.evU16InputZeroPointNotValid,
4923 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004924 TosaErrorValidator.evScaleTrue,
4925 TosaErrorValidator.evScaleNotTrue,
4926 TosaErrorValidator.evWrongInputType,
4927 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004928 TosaErrorValidator.evWrongInputList,
4929 TosaErrorValidator.evWrongOutputList,
4930 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004931 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004932 # Custom
4933 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004934 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004935 # Two variants of cond_if, one that generates one of two constant tensors (no
4936 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4937 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004938 "cond_if_const": {
4939 "op": Op.COND_IF,
4940 "operands": (0, 2),
4941 "build_fcn": (
4942 build_cond_if_const,
4943 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004944 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004945 TosaArgGen.agCondIf,
4946 ),
4947 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004948 "error_if_validators": (
4949 TosaErrorValidator.evOutputListThenGraphMismatch,
4950 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004951 TosaErrorValidator.evCondIfCondNotMatchingBool,
4952 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004953 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004954 },
4955 "cond_if_binary": {
4956 "op": Op.COND_IF,
4957 "operands": (2, 0),
4958 "build_fcn": (
4959 build_cond_if_binary,
4960 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004961 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004962 TosaArgGen.agCondIf,
4963 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004964 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004965 "error_if_validators": (
4966 TosaErrorValidator.evInputListThenGraphMismatch,
4967 TosaErrorValidator.evInputListElseGraphMismatch,
4968 TosaErrorValidator.evOutputListThenGraphMismatch,
4969 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004970 TosaErrorValidator.evCondIfCondNotMatchingBool,
4971 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004972 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004973 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004974 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004975 "while_loop": {
4976 "op": Op.WHILE_LOOP,
4977 "operands": (0, 1),
4978 "build_fcn": (
4979 build_while_loop,
4980 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004981 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004982 TosaArgGen.agWhileLoop,
4983 ),
4984 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004985 "error_if_validators": (
4986 TosaErrorValidator.evInputListOutputListMismatch,
4987 TosaErrorValidator.evInputListCondGraphMismatch,
4988 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4989 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4990 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004991 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004992 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004993 },
Luke Hutton57287132023-02-06 14:54:18 +00004994 "fft2d": {
4995 "op": Op.FFT2D,
4996 "operands": (2, 0),
4997 "rank": (3, 3),
4998 "build_fcn": (
4999 build_fft2d,
5000 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00005001 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00005002 TosaArgGen.agFFT2d,
5003 ),
5004 "types": [DType.FP32],
5005 "error_if_validators": (
5006 TosaErrorValidator.evWrongInputType,
5007 TosaErrorValidator.evWrongOutputType,
5008 TosaErrorValidator.evWrongInputList,
5009 TosaErrorValidator.evWrongOutputList,
5010 TosaErrorValidator.evWrongRank,
5011 TosaErrorValidator.evBatchMismatch,
5012 TosaErrorValidator.evKernelNotPowerOfTwo,
5013 TosaErrorValidator.evFFTInputShapeMismatch,
5014 TosaErrorValidator.evFFTOutputShapeMismatch,
5015 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00005016 "data_gen": {
5017 "fp": (gtu.DataGenType.DOT_PRODUCT,),
5018 },
Luke Hutton57287132023-02-06 14:54:18 +00005019 },
Luke Hutton261b7b62023-01-10 14:50:31 +00005020 "rfft2d": {
5021 "op": Op.RFFT2D,
5022 "operands": (1, 0),
5023 "rank": (3, 3),
5024 "build_fcn": (
5025 build_rfft2d,
5026 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005027 TosaTensorValuesGen.tvgLazyGenDefault,
5028 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00005029 ),
5030 "types": [DType.FP32],
5031 "error_if_validators": (
5032 TosaErrorValidator.evWrongInputType,
5033 TosaErrorValidator.evWrongOutputType,
5034 TosaErrorValidator.evWrongInputList,
5035 TosaErrorValidator.evWrongOutputList,
5036 TosaErrorValidator.evWrongRank,
5037 TosaErrorValidator.evBatchMismatch,
5038 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00005039 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00005040 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005041 "data_gen": {
5042 "fp": (gtu.DataGenType.DOT_PRODUCT,),
5043 },
Luke Hutton261b7b62023-01-10 14:50:31 +00005044 },
Won Jeon74342e52024-01-09 00:34:40 +00005045 # Shape
5046 "add_shape": {
5047 "op": Op.ADD_SHAPE,
5048 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005049 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005050 "build_fcn": (
5051 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005052 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005053 TosaTensorValuesGen.tvgAddSub,
5054 TosaArgGen.agNone,
5055 ),
5056 "types": [DType.SHAPE],
5057 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5058 },
5059 "sub_shape": {
5060 "op": Op.SUB_SHAPE,
5061 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005062 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005063 "build_fcn": (
5064 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005065 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005066 TosaTensorValuesGen.tvgAddSub,
5067 TosaArgGen.agNone,
5068 ),
5069 "types": [DType.SHAPE],
5070 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5071 },
5072 "mul_shape": {
5073 "op": Op.MUL_SHAPE,
5074 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005075 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005076 "build_fcn": (
5077 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005078 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005079 TosaTensorValuesGen.tvgMul,
5080 TosaArgGen.agNone,
5081 ),
5082 "types": [DType.SHAPE],
5083 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5084 },
5085 "div_shape": {
5086 "op": Op.DIV_SHAPE,
5087 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005088 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005089 "build_fcn": (
5090 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005091 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005092 TosaTensorValuesGen.tvgIntDiv,
5093 TosaArgGen.agNone,
5094 ),
5095 "types": [DType.SHAPE],
5096 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5097 },
5098 "concat_shape": {
5099 "op": Op.CONCAT_SHAPE,
5100 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005101 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005102 "build_fcn": (
5103 build_concat,
5104 TosaTensorGen.tgConcat,
5105 TosaTensorValuesGen.tvgConcat,
5106 TosaArgGen.agNone,
5107 ),
5108 "types": [DType.SHAPE],
5109 "error_if_validators": (),
5110 },
5111 "const_shape": {
5112 "op": Op.CONST_SHAPE,
5113 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005114 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005115 "build_fcn": (
5116 build_const,
5117 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005118 TosaTensorValuesGen.tvgLazyGenDefault,
5119 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005120 ),
5121 "types": [DType.SHAPE],
5122 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005123 }
5124
Kevin Cheng550ccc52021-03-03 11:21:43 -08005125
Eric Kunzee5e26762020-10-13 16:11:07 -07005126class OutputShaper:
5127 # Methods in this class compute the expected output shape and datatype
5128 # for common classes of operations
def __init__(self):
    """Create an OutputShaper.

    The instance holds no state, so there is nothing to initialise here.
    """
5131
5132 # These methods return arguments that can be used for
5133 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Register the output tensor of a broadcasting binary elementwise op.

    For an error-free test, every size-1 dimension of ``a`` takes the size of
    the matching dimension of ``b``. For any ERROR_IF test, ``a``'s shape is
    used unchanged as the starting point.

    Args:
        ser: serializer; the result is produced by ``ser.addOutput``.
        rng: random generator. One ``integers`` draw (the fuzzed dimension)
            is always made; a ``choice`` draw follows for
            ``ErrorIf.WrongOutputType``.
        a: first input tensor (``.shape`` and ``.dtype`` are read).
        b: second input tensor. It must share ``a``'s dtype, and its rank
            unless ``error_name`` is ``ErrorIf.RankMismatch``.
        error_name: ERROR_IF case to build the output for, or None.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    broadcast = error_name is None
    # b.shape[idx] is only read while broadcasting, so a b of different rank
    # (RankMismatch) is never indexed.
    shape = [
        b.shape[idx] if (dim == 1 and broadcast) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn for every error case, so RNG consumption does not depend on it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP16,
        DType.BF16,
        DType.FP32,
    )
    wrong_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(shape, wrong_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005167
@staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Register the output tensor of a binary op whose inputs must match exactly.

    ``a`` and ``b`` must have the same rank, the same size in every dimension,
    and the same dtype; these are checked with ``assert``. The output gets a
    fresh copy of ``a``'s shape and ``a``'s dtype.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    for a_dim, b_dim in zip(a.shape, b.shape):
        assert a_dim == b_dim

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005179
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Register the output tensor of an elementwise unary op.

    The output always keeps ``a``'s shape. Its dtype is ``a.dtype``, except
    for ``ErrorIf.WrongOutputType``. In that case ``rng.choice`` picks a dtype
    other than ``a.dtype`` from a fixed candidate list.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005198
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Register the output tensor of a SELECT-style op (cond, a, b).

    The output shape follows ``cond``. For an error-free test, each size-1
    dimension of ``cond`` becomes the largest of the three inputs' sizes in
    that dimension.

    Args:
        ser: serializer; the result is produced by ``ser.addOutput``.
        rng: random generator. One ``integers`` draw is always made; a
            ``choice`` draw follows for ``ErrorIf.WrongOutputType``.
        cond, a, b: input tensors. ``a`` and ``b`` must share a dtype; all
            three must share a rank unless testing ``ErrorIf.RankMismatch``.
        error_name: ERROR_IF case to build the output for, or None.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if (dim == 1 and error_name is None)
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # Drawn for every error case, so RNG consumption does not depend on it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, a.dtype)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))
    return ser.addOutput(shape, wrong_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005232
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Register the BOOL output tensor of a binary comparison op.

    Each size-1 dimension of ``a`` that has a counterpart in ``b`` takes
    ``b``'s size; every other dimension keeps ``a``'s size.

    Args:
        ser: serializer; the result is produced by ``ser.addOutput``.
        rng: random generator. One ``integers`` draw is always made; a
            ``choice`` draw follows for ``ErrorIf.WrongOutputType``.
        a, b: input tensors with equal dtypes, and equal ranks unless
            testing ``ErrorIf.RankMismatch``.
        error_name: ERROR_IF case to build the output for, or None.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast size-1 dims of `a` wherever `b` has that dimension.
    shape = [
        b.shape[idx] if (dim == 1 and idx < len(b.shape)) else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn for every error case, so RNG consumption does not depend on it.
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        shape[fuzz_idx] += 1

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(shape, DType.BOOL)

    # BOOL is not among these candidates, so any pick is a wrong output type.
    wrong_dtype = rng.choice(
        [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
    )
    return ser.addOutput(shape, wrong_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005266
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor of a reduction along ``axis``.

    Normally the reduced axis collapses to size 1. For the axis-related
    ERROR_IF cases (``AxisSmallerZero``, ``AxisLargerRank``,
    ``ShapeOfAxisNotOne``) it is not collapsed. For ``ShapeOfAxisNotOne``,
    an axis that is already 1 is replaced by ``rng.integers(2, 10)``.

    Args:
        ser: serializer; the result is produced by ``ser.addOutput``.
        rng: random generator, used only by the error cases above and
            ``ErrorIf.WrongOutputType``.
        a: input tensor (``.shape`` is copied, ``.dtype`` read).
        axis: index of the reduced dimension in ``a.shape``.
        error_name: ERROR_IF case to build the output for, or None.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()
    keeps_axis = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keeps_axis:
        out_shape[axis] = 1
    elif error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        # Force the reduced axis away from 1 so the ERROR_IF triggers.
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005295
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Register the INT32 output tensor of an ARGMAX along ``axis``.

    The output shape is ``a``'s shape with ``axis`` removed. The axis is kept
    for the ``AxisSmallerZero`` / ``AxisLargerRank`` ERROR_IF cases. The
    ``ArgmaxOutput*Mismatch`` cases then perturb the shape, and
    ``WrongOutputType`` swaps INT32 for another dtype.

    Args:
        ser: serializer; the result is produced by ``ser.addOutput``.
        rng: random generator, used only by the ERROR_IF cases.
        a: input tensor (``.shape`` is copied).
        axis: index of the reduced dimension in ``a.shape``.
        error_name: ERROR_IF case to build the output for, or None.

    Returns:
        Whatever ``ser.addOutput(shape, dtype)`` returns.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        # The output drops the reduced axis.
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Either drop the leading dim (when more than one is left) or append
        # an extra one, so the output rank is wrong.
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Grow every remaining dim by a draw from rng.integers(1, 10).
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(out_shape, DType.INT32)

    candidate_dtypes = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    wrong_dtype = rng.choice(list(set(candidate_dtypes) - {DType.INT32}))
    return ser.addOutput(out_shape, wrong_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005331
@staticmethod
def _get_conv_output_type(input_dtype):
    """Return the convolution output dtype for a given input dtype.

    FP16/BF16/FP32 inputs keep their dtype, FP8 (E4M3/E5M2) inputs give FP16,
    INT8/INT4 inputs give INT32 and INT16 inputs give INT48.

    Raises:
        ValueError: if ``input_dtype`` is not a supported convolution input.
    """
    if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
        return input_dtype
    if input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        return DType.FP16
    if input_dtype in (DType.INT8, DType.INT4):
        return DType.INT32
    if input_dtype == DType.INT16:
        return DType.INT48
    # Fail loudly: returning None here would create an output tensor with no
    # valid dtype further down the line.
    raise ValueError(f"Unsupported convolution data type {input_dtype}")
5343
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output feature map tensor of CONV2D.

    Layouts: IFM is NHWC, filter is OHWI, OFM is NHWC. ``padding[0:2]`` pads
    the height and ``padding[2:4]`` the width. ``strides`` and ``dilations``
    are indexed [height, width]. ``accum_dtype`` is not used here.

    ErrorIf.ConvOutputShapeMismatch enlarges H and/or W by whole strides.
    ErrorIf.WrongInputType uses INT32 as the output dtype. ErrorIf.WrongOutputType
    picks a usable dtype outside the valid output types.
    """

    def _out_dim(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Dilated convolution output size, using floor division
        return (
            in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation
        ) // stride + 1

    out_h = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        pick = rng.choice(steps)
        # Grow by whole strides so the non-integer ERROR_IF is not also hit
        if pick in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if pick in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid, so any plausible output dtype will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excluded = [DType.FP16, DType.FP32]
        elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
            excluded = [DType.FP16]
        else:
            excluded = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005397
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output feature map tensor of CONV3D.

    Layouts: IFM is NDHWC, filter is ODHWI, OFM is NDHWC. ``padding`` holds
    pairs for depth (0:2), height (2:4) and width (4:6). ``strides`` and
    ``dilations`` are indexed [depth, height, width]. ``accum_dtype`` is not
    used here.

    ErrorIf.ConvOutputShapeMismatch enlarges one or all of D/H/W by whole
    strides. The dtype handling mirrors the other convolution shapers.
    """

    def _out_dim(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Dilated convolution output size, using floor division
        return (
            in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation
        ) // stride + 1

    out_d = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_h = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )
    out_w = _out_dim(
        ifm.shape[3], padding[4], padding[5], filter.shape[3], dilations[2], strides[2]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        pick = rng.choice(steps)
        # Grow by whole strides so the non-integer ERROR_IF is not also hit
        if pick in (1, 4):
            out_d += rng.choice(steps) * strides[0]
        if pick in (2, 4):
            out_h += rng.choice(steps) * strides[1]
        if pick in (3, 4):
            out_w += rng.choice(steps) * strides[2]

    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid, so any plausible output dtype will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
5459
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Create the output feature map tensor of DEPTHWISE_CONV2D.

    Layouts: IFM is NHWC, filter is HWCM, and the OFM is NHW with C*M
    channels. ``padding[0:2]`` pads the height and ``padding[2:4]`` the width.
    ``accum_dtype`` is not used here.
    """

    def _out_dim(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Dilated convolution output size, using floor division
        return (
            in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation
        ) // stride + 1

    out_h = _out_dim(
        ifm.shape[1], padding[0], padding[1], filter.shape[0], dilations[0], strides[0]
    )
    out_w = _out_dim(
        ifm.shape[2], padding[2], padding[3], filter.shape[1], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        pick = rng.choice(steps)
        # Grow by whole strides so the non-integer ERROR_IF is not also hit
        if pick in (1, 3):
            out_h += rng.choice(steps) * strides[0]
        if pick in (2, 3):
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    if error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid, so any plausible output dtype will do
        out_dtype = DType.INT32
    else:
        out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excluded)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005510
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Create the NHWC output tensor of a 2D pooling operator.

    ``kernel`` and ``stride`` are indexed [height, width]. ``pad[0:2]`` pads
    the height and ``pad[2:4]`` the width. If any stride is non-positive or
    any pad is negative, the spatial dims are set to 1. The dtype follows the
    input unless ErrorIf.WrongOutputType is being generated.
    """
    bad_params = stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0
    if bad_params:
        # The test is invalid anyway, so just use 1x1 spatial dimensions
        out_h = out_w = 1
    else:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        pick = rng.choice(steps)
        # Grow by whole strides so the non-integer ERROR_IF is not also hit
        if pick in (1, 3):
            out_h += rng.choice(steps) * stride[0]
        if pick in (2, 3):
            out_w += rng.choice(steps) * stride[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = ifm.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {ifm.dtype}))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005550
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Create the [N, OC] output tensor of FULLY_CONNECTED.

    ``input`` is [N, IC] and ``filter`` is [OC, IC]. The output dtype is
    ``accum_dtype``, which the argument generator validates (or deliberately
    invalidates for ERROR_IF tests). ``rng`` and ``error_name`` are unused.
    """
    batch = input.shape[0]
    out_channels = filter.shape[0]
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005563
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Create the [N, H, W] output tensor of MATMUL.

    ``a`` is [N, H, C] and ``b`` is [N, C, W]. Normally the output dtype is
    ``accum_dtype``, which is validated in the argument generator.
    ErrorIf.WrongOutputType picks a dtype from a per-input-dtype list that
    omits the valid output type(s). ErrorIf.WrongInputType uses INT32.
    """
    out_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongOutputType:
        in_dtype = a.dtype
        if in_dtype == DType.INT8:
            # Every listed dtype except INT32
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT48,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif in_dtype == DType.INT16:
            # Every listed dtype except INT48
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif in_dtype in (DType.FP8E4M3, DType.FP8E5M2):
            # Every listed dtype except FP16
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        elif in_dtype in (DType.FP32, DType.FP16, DType.BF16):
            # Integer and FP8 dtypes only
            bad_dtypes = (
                DType.INT4,
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FP8E4M3,
                DType.FP8E5M2,
            )
        out_dtype = rng.choice(a=bad_dtypes)
    elif error_name == ErrorIf.WrongInputType:
        # The input dtype is invalid, so any plausible output dtype will do
        out_dtype = DType.INT32
    else:
        out_dtype = accum_dtype

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005629
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Create the output tensor of CONCAT.

    The output starts from the first input's shape. Along ``axis`` it adds the
    sizes of all remaining inputs, unless ranks differ or the axis is invalid
    (then the first input's shape is kept). ErrorIf.ConcatShapeSumMismatch
    adds ``rng.integers(5, 10)`` to that axis. The dtype follows the first
    input unless ErrorIf.WrongOutputType is being generated.
    """
    first = inputs[0]
    out_shape = first.shape.copy()

    cannot_sum = error_name in (
        ErrorIf.ConcatInputRankMismatch,
        ErrorIf.AxisLargerRank,
        ErrorIf.AxisSmallerZero,
    )
    if not cannot_sum:
        for other in inputs[1:]:
            out_shape[axis] += other.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        out_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = first.dtype
    else:
        candidate_dtypes = {
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        }
        out_dtype = rng.choice(list(candidate_dtypes - {first.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005665
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Create the output tensor of PAD.

    Each output dimension is ``padding[i][0] + padding[i][1] + a.shape[i]``.
    ErrorIf.PadOutputShapeMismatch adds 1 or 2 to a random dimension, and
    ErrorIf.RankMismatch replaces the shape via ``gtu.get_rank_mismatch_shape``.
    The dtype follows ``a`` unless ErrorIf.WrongOutputType is being generated.
    """
    out_shape = a.shape.copy()
    for idx, dim in enumerate(a.shape):
        out_shape[idx] = padding[idx][0] + padding[idx][1] + dim

    if error_name == ErrorIf.PadOutputShapeMismatch:
        victim = rng.choice(range(len(out_shape)))
        out_shape[victim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005696
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Create the single-element output tensor of DIM.

    The output shape is always [1] and its dtype is DType.SHAPE. For
    ErrorIf.WrongOutputType one of the listed non-SHAPE dtypes is chosen
    instead. ``a`` and ``axis`` are not used here.
    """
    if error_name == ErrorIf.WrongOutputType:
        non_shape_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(non_shape_dtypes)))
    else:
        out_dtype = DType.SHAPE

    return ser.addOutput([1], out_dtype)
5717
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Create the output tensor of RESHAPE with the requested ``shape``.

    ErrorIf.TensorSizeInputOutputMismatch adds ``rng.integers(1, 10)`` to every
    dimension. The dtype follows ``a`` unless ErrorIf.WrongOutputType is being
    generated.
    """
    out_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx, dim in enumerate(out_shape):
            out_shape[idx] = dim + rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005742
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Create the output tensor of SLICE.

    The output shape is ``size`` and the dtype is ``input.dtype``. ``start`` is
    not used here. ERROR_IF variants pick a wrong dtype, nudge every size,
    reuse the input shape, or change the rank.
    """
    if error_name != ErrorIf.WrongOutputType:
        out_dtype = input.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {input.dtype}))

    out_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for idx, dim in enumerate(out_shape):
            # Dims <= 2 only grow, since shrinking them could reach zero
            deltas = [1, 2] if dim <= 2 else [-2, -1, 1, 2]
            out_shape[idx] = dim + rng.choice(deltas)
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        out_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005776
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Create the output tensor of TILE.

    Each output dimension is ``a.shape[i] * multiples[i]``.
    ErrorIf.RankMismatch replaces the shape via ``gtu.get_rank_mismatch_shape``.
    The dtype follows ``a`` unless ErrorIf.WrongOutputType is being generated.
    """
    out_shape = a.shape.copy()
    assert len(multiples) == len(out_shape)

    for idx, dim in enumerate(a.shape):
        out_shape[idx] = dim * multiples[idx]

    if error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005805
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Create the output tensor of TRANSPOSE.

    Output dimension ``i`` is ``a.shape[perms[i]]``. The permutation is skipped
    when it is itself the ERROR_IF (out-of-bounds or repeated index).
    Size-mismatch and rank-mismatch variants then perturb the shape. The dtype
    follows ``a`` unless ErrorIf.WrongOutputType is being generated.
    """
    out_shape = a.shape.copy()
    assert len(perms) == len(out_shape)

    perms_invalid = error_name in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]
    if not perms_invalid:
        for idx in range(len(out_shape)):
            out_shape[idx] = a.shape[perms[idx]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for idx in range(len(out_shape)):
            out_shape[idx] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        out_shape = gtu.get_rank_mismatch_shape(rng, out_shape)

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = a.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005838
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Create the output tensor of GATHER.

    ``values`` is rank 3 (checked unless ErrorIf.WrongRank) and ``indices`` is
    rank 2 with a matching first dimension. The output shape is
    [values.shape[0], indices.shape[1], values.shape[2]]. The dtype follows
    ``values`` unless ErrorIf.WrongOutputType is being generated.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch = values.shape[0]
    width = indices.shape[1]
    channels = values.shape[2]
    out_shape = [batch, width, channels]

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08005864
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Create the output tensor of SCATTER.

    The output takes the shape of ``values_in``. The shape object is reused
    rather than copied. Rank checks apply unless ErrorIf.WrongRank: values_in
    is rank 3, indices rank 2 and input rank 3, with N/W/C agreeing as
    asserted below. The dtype follows ``values_in`` unless
    ErrorIf.WrongOutputType is being generated.
    """
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    if error_name != ErrorIf.WrongOutputType:
        out_dtype = values_in.dtype
    else:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        out_dtype = rng.choice(list(set(candidate_dtypes) - {values_in.dtype}))

    return ser.addOutput(values_in.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005895
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Create the output tensor of TABLE.

    The output keeps the input shape. INT16 inputs give INT32 outputs and any
    other input dtype gives INT8. The input is asserted to be INT8/INT16 unless
    ErrorIf.WrongInputType is being generated. For ErrorIf.WrongOutputType a
    different dtype from the listed ones is chosen.
    """
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        out_dtype = DType.INT32
    else:
        out_dtype = DType.INT8

    if error_name == ErrorIf.WrongOutputType:
        candidate_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        wrong_dtypes = [dt for dt in candidate_dtypes if dt != out_dtype]
        out_dtype = rng.choice(wrong_dtypes)

    return ser.addOutput(input.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005915
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Create the output tensor of RESIZE.

    ``scale`` is read as [y_n, y_d, x_n, x_d]. Output height and width are
    ((in - 1) * n - offset + border) // d + 1 for each spatial axis. ``mode``
    and ``input_dtype`` are not used here.

    For any ERROR_IF the dims are clamped to at least 1. Except for
    ErrorIf.MaxDimExceeded they are also clamped below
    ``gtu.MAX_RESIZE_DIMENSION``. Shape/batch/channel/rank mismatch variants
    then perturb the result.
    """
    y_num = scale[0]
    y_den = scale[1]
    x_num = scale[2]
    x_den = scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Use positive values here so the size computation stays well defined
        y_num, y_den = max(y_num, 1), max(y_den, 1)
        x_num, x_den = max(x_num, 1), max(x_den, 1)

    out_h = ((input.shape[1] - 1) * y_num - offset[0] + border[0]) // y_den + 1
    out_w = ((input.shape[2] - 1) * x_num - offset[1] + border[1]) // x_den + 1

    if error_name is not None:
        # Scale/offset/border may have been altered for the ERROR_IF, so keep
        # the output tensor itself valid
        out_h, out_w = max(out_h, 1), max(out_w, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            limit = gtu.MAX_RESIZE_DIMENSION - 1
            out_h, out_w = min(out_h, limit), min(out_w, limit)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        options = [1, 2, 3]
        pick = rng.choice(options)
        # Step by the scale denominator so the non-integer ERROR_IF is not hit
        if pick in (1, 3):
            if out_h + y_den >= gtu.MAX_RESIZE_DIMENSION:
                out_h -= y_den
                assert out_h > 0  # Should have been caught in agResize
            else:
                out_h += y_den
        if pick in (2, 3):
            if out_w + x_den >= gtu.MAX_RESIZE_DIMENSION:
                out_w -= x_den
                assert out_w > 0  # Should have been caught in agResize
            else:
                out_w += x_den

    if error_name == ErrorIf.WrongRank:
        # NOTE(review): the last entry reuses input.shape[0], presumably
        # because a wrong-rank input may lack a channel dimension
        output_dims = [input.shape[0], out_h, out_w, input.shape[0]]
    elif error_name == ErrorIf.BatchMismatch:
        output_dims = [
            input.shape[0] + rng.integers(1, 10),
            out_h,
            out_w,
            input.shape[3],
        ]
    elif error_name == ErrorIf.ChannelMismatch:
        output_dims = [
            input.shape[0],
            out_h,
            out_w,
            input.shape[3] + rng.integers(1, 10),
        ]
    else:
        output_dims = [input.shape[0], out_h, out_w, input.shape[3]]

    return serializer.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005994
5995 @staticmethod
Matthew Haddonbb5676f2021-10-13 11:30:30 +01005996 def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005997 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005998
5999 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01006000 def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01006001 if error_name == ErrorIf.ConvOutputShapeMismatch:
6002 choices = [1, 2, 3]
6003 change = rng.choice(choices)
6004 if change in [1, 3]:
6005 output_shape[1] = output_shape[1] + rng.choice(choices)
6006 if change in [2, 3]:
6007 output_shape[2] = output_shape[2] + rng.choice(choices)
6008
James Ward8b390432022-08-12 20:48:56 +01006009 if error_name == ErrorIf.WrongInputType:
Les Bell0e027d42021-11-09 14:42:14 +00006010 # Pick some potentially correct output dtype if input type is incorrect
6011 out_dtype = DType.INT32
Eric Kunzee5e26762020-10-13 16:11:07 -07006012 else:
Tai Lyf36f2562024-03-14 16:21:29 +00006013 out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
Les Bell0e027d42021-11-09 14:42:14 +00006014
6015 if error_name == ErrorIf.WrongOutputType:
James Ward8b390432022-08-12 20:48:56 +01006016 if ifm.dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01006017 excludes = [DType.FP16, DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01006018 else:
6019 excludes = [out_dtype]
Jeremy Johnson1271c442023-09-05 11:39:26 +01006020 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Les Bell0e027d42021-11-09 14:42:14 +00006021 out_dtype = rng.choice(wrong_dtypes)
Eric Kunzee5e26762020-10-13 16:11:07 -07006022
Kevin Cheng550ccc52021-03-03 11:21:43 -08006023 return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00006024
6025 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +00006026 def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
6027 outputs = []
6028
6029 assert ifm1.dtype == ifm2.dtype
6030 input_dtype = ifm1.dtype
6031
6032 if error_name != ErrorIf.FFTInputShapeMismatch:
6033 assert ifm1.shape == ifm2.shape
6034
6035 input_shape = ifm1.shape
6036 if error_name != ErrorIf.WrongRank:
6037 assert len(input_shape) == 3
6038
6039 output_shape = input_shape.copy()
6040 output_dtype = input_dtype
6041
6042 if error_name == ErrorIf.WrongOutputType:
6043 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01006044 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton57287132023-02-06 14:54:18 +00006045 output_dtype = rng.choice(wrong_dtypes)
6046 elif error_name == ErrorIf.BatchMismatch:
6047 output_shape[0] += rng.integers(1, 10)
6048 elif error_name == ErrorIf.FFTOutputShapeMismatch:
6049 modify_dim = rng.choice([1, 2])
6050 output_shape[modify_dim] += rng.integers(1, 10)
6051
6052 outputs.append(serializer.addOutput(output_shape, output_dtype))
6053 outputs.append(serializer.addOutput(output_shape, output_dtype))
6054 return outputs
6055
6056 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +00006057 def rfft2dOp(serializer, rng, value, error_name=None):
6058 outputs = []
6059
6060 input_shape = value.shape
6061 if error_name != ErrorIf.WrongRank:
6062 assert len(input_shape) == 3
6063
6064 output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
6065
6066 output_dtype = value.dtype
6067 if error_name == ErrorIf.WrongOutputType:
6068 excludes = [DType.FP32]
Jeremy Johnson1271c442023-09-05 11:39:26 +01006069 wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
Luke Hutton261b7b62023-01-10 14:50:31 +00006070 output_dtype = rng.choice(wrong_dtypes)
6071 elif error_name == ErrorIf.BatchMismatch:
Luke Hutton57287132023-02-06 14:54:18 +00006072 output_shape[0] += rng.integers(1, 10)
6073 elif error_name == ErrorIf.FFTOutputShapeMismatch:
6074 modify_dim = rng.choice([1, 2])
6075 output_shape[modify_dim] += rng.integers(1, 10)
Luke Hutton261b7b62023-01-10 14:50:31 +00006076
6077 outputs.append(serializer.addOutput(output_shape, output_dtype))
6078 outputs.append(serializer.addOutput(output_shape, output_dtype))
6079 return outputs
Won Jeon74342e52024-01-09 00:34:40 +00006080
6081 @staticmethod
6082 def addShapeOp(ser, rng, a, b, error_name=None):
6083 if error_name != ErrorIf.RankMismatch:
6084 assert len(a.shape) == len(b.shape)
6085 assert a.dtype == b.dtype
6086
6087 shape = []
6088 for i in range(len(a.shape)):
6089 shape.append(a.shape[i])
6090
6091 fuzz_idx = rng.integers(0, len(a.shape))
6092 if error_name == ErrorIf.DimensionMismatch:
6093 shape[fuzz_idx] += 1
6094
6095 if error_name == ErrorIf.WrongOutputType:
6096 wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
6097 outputDType = rng.choice(wrong_dtypes)
6098 else:
6099 outputDType = DType.SHAPE
6100 return ser.addOutput(shape, outputDType)