blob: 71d7fcc003e205e2e7cdb7f41cbe831a30cb32d3 [file] [log] [blame]
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00001# Copyright (c) 2020-2024, ARM Limited.
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00002# SPDX-License-Identifier: Apache-2.0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003import json
Jeremy Johnsonaf090182024-02-13 18:25:39 +00004import logging
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00005import os
Tai Ly60dc48c2024-03-08 22:19:41 +00006import struct
Matthew Haddon630c17c2021-10-14 15:05:41 +01007from copy import deepcopy
Jeremy Johnson1271c442023-09-05 11:39:26 +01008from datetime import datetime
9from pathlib import Path
Eric Kunzee5e26762020-10-13 16:11:07 -070010
Jeremy Johnson1271c442023-09-05 11:39:26 +010011import generator.tosa_utils as gtu
Jeremy Johnson5c1364c2022-01-13 15:04:21 +000012import numpy as np
Jeremy Johnson2ec34942021-12-14 16:34:05 +000013import serializer.tosa_serializer as ts
Jeremy Johnson65ba8092023-10-09 16:31:13 +010014from generator.datagenerator import GenerateLibrary
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010015from generator.tosa_arg_gen import TosaArgGen
16from generator.tosa_arg_gen import TosaQuantGen
17from generator.tosa_arg_gen import TosaTensorGen
18from generator.tosa_arg_gen import TosaTensorValuesGen
Jeremy Johnson2ec34942021-12-14 16:34:05 +000019from generator.tosa_error_if import ErrorIf
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010020from generator.tosa_error_if import TosaErrorIfArgGen
21from generator.tosa_error_if import TosaErrorValidator
22from generator.tosa_error_if import TosaInvalidValidator
Jeremy Johnson0a6d1de2023-09-27 14:59:43 +010023from generator.tosa_random_gen import TosaHashRandomGenerator
24from generator.tosa_random_gen import TosaRandomGenerator
Jeremy Johnson1271c442023-09-05 11:39:26 +010025from schemavalidation.schemavalidation import TestDescSchemaValidator
Les Bell0e027d42021-11-09 14:42:14 +000026from tosa.DType import DType
27from tosa.Op import Op
Matthew Haddonb724efc2021-08-25 16:40:29 +010028
# Banner placed at the top of the generated C++ meta data files. The
# copyright year is read from the clock when this module is imported.
TOSA_AUTOGENERATED_HEADER = (
    f"// Copyright (c) {datetime.today().year}, ARM Limited\n"
    "// SPDX-License-Identifier: Apache-2.0\n"
    "// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests\n"
)

# Default logging configuration plus the named logger used by this tool.
logging.basicConfig()
logger = logging.getLogger("tosa_verif_build_tests")
36
Matthew Haddonb724efc2021-08-25 16:40:29 +010037
Eric Kunzee5e26762020-10-13 16:11:07 -070038class TosaTestGen:
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +000039 # This currently matches the 8K level defined in the specification.
Jeremy Johnsonb2099702023-04-12 15:59:01 +010040 TOSA_8K_LEVEL_MAX_SCALE = 64
Jeremy Johnson0c716862023-04-13 17:18:19 +010041 TOSA_8K_LEVEL_MAX_KERNEL = 8192
42 TOSA_8K_LEVEL_MAX_STRIDE = 8192
Jeremy Johnson97eb75f2021-07-08 11:58:02 +010043
Jeremy Johnson1271c442023-09-05 11:39:26 +010044 # Main compliance dot product statistical test range
Jeremy Johnson30476252023-11-20 16:15:30 +000045 TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
Jeremy Johnson1271c442023-09-05 11:39:26 +010046 TOSA_MI_DOT_PRODUCT_MIN = 1000
47
def __init__(self, args):
    """Initialise the generator from the parsed program arguments.

    Args:
        args: options object. This method reads ``output_dir``,
            ``random_seed``, ``generate_lib_path``, ``tensor_shape_range``
            and ``tensor_fp_value_range`` from it and keeps the whole
            object as ``self.args``.
    """
    self.args = args
    self.basePath = args.output_dir
    self.random_seed = args.random_seed
    self.ser = None
    self.createDynamicOpLists()
    self.initOpListDefaults()
    self.quantGen = TosaQuantGen()
    self.global_rng = None
    # When set, makeShape returns this shape instead of a random one
    self.targetted_shape = None
    # Validator for the JSON test descriptor
    self.descSchemaValidator = TestDescSchemaValidator()
    # The data generator library can be needed while setting up compliance
    # even when the data itself is produced later (lazy_data_generation)
    self.dgl = GenerateLibrary(args.generate_lib_path)

    def clamp_fp_range(requested, limit):
        # Map the "max"/"-max" program arguments onto the data type limit
        # and trim any other non-zero bound into [-limit, limit]
        clamped = []
        for bound in requested:
            if bound == "max":
                bound = limit
            elif bound == "-max":
                bound = -limit
            elif bound < 0:
                bound = max(bound, -limit)
            elif bound > 0:
                bound = min(bound, limit)
            clamped.append(bound)
        clamped.sort()
        return tuple(clamped)

    # Per data type value ranges handed to the random generator
    self.random_dtype_range = {
        DType.SHAPE: tuple(self.args.tensor_shape_range[0:2])
    }
    fp_dtypes = (
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    for fp_dtype in fp_dtypes:
        fp_limit = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[fp_dtype]
        self.random_dtype_range[fp_dtype] = clamp_fp_range(
            args.tensor_fp_value_range, fp_limit
        )
    self.resetGlobalRNG()
92
def resetGlobalRNG(self):
    """Rebuild ``self.global_rng`` from the stored seed and dtype ranges."""
    seed = self.random_seed
    dtype_ranges = self.random_dtype_range
    self.global_rng = TosaRandomGenerator(seed, dtype_ranges)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +010095
def createSerializer(self, opName, testPath):
    """Create the test output directory and a serializer that writes to it.

    Stores ``opName/testPath`` in ``self.testPath``, makes that directory
    under ``self.basePath`` if it does not exist, and stores a new
    ``ts.TosaSerializer`` in ``self.ser``.

    Args:
        opName: first path component of the test directory.
        testPath: remaining path of the test directory.
    """
    self.testPath = os.path.join(opName, testPath)
    fullPath = os.path.join(self.basePath, self.testPath)
    os.makedirs(fullPath, exist_ok=True)

    if self.args.lazy_data_gen:
        # Lazy data generation - so make constants files
        constMode = ts.ConstMode.INPUTS
    elif self.args.dump_consts:
        constMode = ts.ConstMode.EMBED_DUMP
    else:
        # Embed const data in the flatbuffer
        constMode = ts.ConstMode.EMBED
    self.ser = ts.TosaSerializer(fullPath, constMode)
Eric Kunzee5e26762020-10-13 16:11:07 -0700109
def getSerializer(self):
    """Return the current serializer (``None`` until createSerializer runs)."""
    serializer = self.ser
    return serializer
112
def serialize(self, testName, metaData=None):
    """Write the flatbuffer, optional C++ meta data files and desc.json.

    All files are written into ``self.basePath / self.testPath``.

    Args:
        testName: base name used for the ``.tosa`` file and for the
            generated ``.cpp`` meta data files.
        metaData: optional dict stored under ``"meta"`` in desc.json. Its
            ``"data_gen"`` entry is also written as C++ source when lazy
            data generation is enabled, and its ``"compliance"`` entry is
            always written as C++ source when present.
    """
    test_dir = Path(self.basePath) / self.testPath
    flatbuffer_name = f"{testName}.tosa"

    # TOSA flatbuffer binary
    with (test_dir / flatbuffer_name).open("wb") as fb_file:
        fb_file.write(self.ser.serialize())

    # JSON descriptor produced by the serializer, plus any extra meta data
    desc = json.loads(self.ser.writeJson(flatbuffer_name))
    if metaData:
        desc["meta"] = metaData

    # Validate desc.json before anything derived from it is written out
    self.descSchemaValidator.validate_config(desc)

    # Each entry: (file name, purpose comment, C declaration, metaData key)
    cpp_outputs = []
    if metaData:
        if "data_gen" in metaData and self.args.lazy_data_gen:
            # Data generation meta data as CPP data
            cpp_outputs.append(
                (
                    f"{testName}_meta_data_gen.cpp",
                    "// Test meta data for data generation setup\n\n",
                    f'const char* json_tdg_config_{test_dir.stem} = R"(',
                    "data_gen",
                )
            )
        if "compliance" in metaData:
            # Compliance meta data as CPP data
            cpp_outputs.append(
                (
                    f"{testName}_meta_compliance.cpp",
                    "// Test meta data for compliance validation\n\n",
                    f'const char* json_tvf_config_{test_dir.stem} = R"(',
                    "compliance",
                )
            )

    for file_name, purpose, declaration, key in cpp_outputs:
        with (test_dir / file_name).open("w") as cpp_file:
            cpp_file.write(TOSA_AUTOGENERATED_HEADER)
            cpp_file.write(purpose)
            cpp_file.write(declaration)
            json.dump(metaData[key], cpp_file)
            cpp_file.write(')";\n\n')

    # desc.json itself
    with (test_dir / "desc.json").open("w") as desc_file:
        json.dump(desc, desc_file, indent=1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700156
def buildPlaceholderTensors(self, rng, shape_list, dtype_list):
    """Add one placeholder tensor per shape/dtype pair to the serializer.

    Args:
        rng: generator providing ``randTensor(shape, dtype)``.
        shape_list: shapes of the tensors to add.
        dtype_list: data types, one per entry of ``shape_list``.

    Returns:
        List of the values returned by ``self.ser.addPlaceholder``, in
        input order. With lazy data generation enabled no random data is
        drawn and ``None`` is passed as the tensor data.
    """
    assert len(shape_list) == len(dtype_list)

    added = []
    for position, tensor_shape in enumerate(shape_list):
        tensor_dtype = dtype_list[position]
        data = (
            None
            if self.args.lazy_data_gen
            else rng.randTensor(tensor_shape, tensor_dtype)
        )
        added.append(self.ser.addPlaceholder(tensor_shape, tensor_dtype, data))
    return added
169
def buildConstTensors(self, rng, shape_list, dtype_list):
    """Add one constant tensor per shape/dtype pair to the serializer.

    Args:
        rng: generator providing ``randTensor(shape, dtype)``.
        shape_list: shapes of the tensors to add.
        dtype_list: data types, one per entry of ``shape_list``.

    Returns:
        List of the values returned by ``self.ser.addConst``, in input
        order. With lazy data generation enabled no random data is drawn
        and ``None`` is passed as the tensor data.
    """
    assert len(shape_list) == len(dtype_list)

    added = []
    for position, tensor_shape in enumerate(shape_list):
        tensor_dtype = dtype_list[position]
        data = (
            None
            if self.args.lazy_data_gen
            else rng.randTensor(tensor_shape, tensor_dtype)
        )
        added.append(self.ser.addConst(tensor_shape, tensor_dtype, data))
    return added
182
def makeShape(self, rng, rank):
    """Return a tensor shape as a numpy int32 array.

    If a non-empty target shape has been set (see ``setTargetShape``) it is
    returned, converted to int32, and ``rng``/``rank`` are not used.
    Otherwise ``rank`` dimensions are drawn from ``rng.integers`` with
    ``low``/``high`` taken from ``self.args.tensor_shape_range``.

    Args:
        rng: generator providing ``integers(low=, high=, size=)``.
        rank: number of dimensions to generate.

    Returns:
        The shape as produced by ``np.int32(...)``.
    """
    target = self.targetted_shape
    # Test for None/emptiness explicitly: truth-testing a multi-element
    # numpy array raises, and a one-element array holding 0 would wrongly
    # count as "no target set".
    if target is not None and np.size(target) > 0:
        return np.int32(target)

    shape_range = self.args.tensor_shape_range
    return np.int32(
        rng.integers(
            low=shape_range[0],
            high=shape_range[1],
            size=rank,
        )
    )
Eric Kunzee5e26762020-10-13 16:11:07 -0700193
def setTargetShape(self, shape):
    """Record ``shape`` as the target shape that makeShape will return."""
    target = shape
    self.targetted_shape = target
196
def shapeStr(self, shape):
    """Return the dimensions of ``shape`` joined with "x", e.g. "1x2x3"."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700205
def typeStr(self, dtype):
    """Return the string name of a data type or of a group of data types.

    Args:
        dtype: a single data type, or a list/tuple of at least two types.

    Returns:
        For a single type, its ``"str"`` entry in ``gtu.DTYPE_ATTRIBUTES``.
        For a list/tuple, the names of the first two members joined with
        "x" (the third is the accumulator and is left out of the name).

    Raises:
        Exception: if a data type is not in ``gtu.DTYPE_ATTRIBUTES``.
    """
    if isinstance(dtype, (list, tuple)):
        assert len(dtype) >= 2
        # Every member is converted, so an unknown type anywhere still raises
        names = [self.typeStr(member) for member in dtype]
        return "x".join(names[:2])

    if dtype not in gtu.DTYPE_ATTRIBUTES:
        raise Exception(
            "Unknown dtype, cannot convert to string: {}".format(dtype)
        )
    return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
Eric Kunzee5e26762020-10-13 16:11:07 -0700219
def constrictBatchSize(self, shape):
    """Cap ``shape[0]`` at the configured maximum batch size.

    The cap is applied only when ``self.args.max_batch_size`` is set and no
    explicit target shapes were requested. ``shape`` is modified in place
    and also returned.
    """
    batch_cap = self.args.max_batch_size
    if not batch_cap or self.args.target_shapes:
        return shape
    shape[0] = min(shape[0], batch_cap)
    return shape
225
def makeDimension(self, rng):
    """Return ``rng.randInt`` called with the configured shape range bounds."""
    shape_range = self.args.tensor_shape_range
    lowest, highest = shape_range[0], shape_range[1]
    return rng.randInt(low=lowest, high=highest)
230
def tensorComplianceMetaData(
    self, op, inputType, argsDict, outputTensor, errorName
):
    """Build the compliance meta data dict for one expected output tensor.

    Args:
        op: operator definition; ``op["op"]`` and the optional
            ``op["compliance"]`` settings are read.
        inputType: data type of the operator input.
        argsDict: test arguments; ``"dg_type"`` is always read, and the
            chosen mode may also read ``"s"``, ``"ks"``/``"ksb"``,
            ``"max_abs_value"`` or ``"n"``.
        outputTensor: result tensor; only its ``dtype`` is used.
        errorName: name of the ERROR_IF being tested, if any.

    Returns:
        ``None`` for error tests and for data types that compliance does
        not support, otherwise a dict holding ``"mode"``, ``"data_type"``
        and any mode specific info.
    """
    # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
    UNSUPPORTED_NON_FP32_INPUT_OPS = (
        Op.MATMUL,
        Op.CONV2D,
        Op.FULLY_CONNECTED,
        Op.DEPTHWISE_CONV2D,
        Op.TRANSPOSE_CONV2D,
        Op.CONV3D,
    )

    # No compliance for error tests or unsupported types currently
    if errorName:
        return None
    if not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
        return None
    if (
        not gtu.dtypeIsSupportedByCompliance(inputType)
        and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS
    ):
        return None

    # "mode" is inserted first so it stays the first key; it is filled in
    # once the branch below has picked it
    compliance_tens = {
        "mode": None,
        # Data type is needed for all FP runs, as refmodel precise mode produces FP64
        "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
    }

    data_gen_type = argsDict["dg_type"]
    # Per-op compliance settings, empty when the op defines none
    op_compliance = op["compliance"] if "compliance" in op else {}

    # The order of these checks sets the precedence between modes
    if data_gen_type == gtu.DataGenType.DOT_PRODUCT:
        mode = gtu.ComplianceMode.DOT_PRODUCT
        compliance_tens["dot_product_info"] = {
            "s": argsDict["s"],
            # Use "ksb" in preference to "ks" when both are present
            "ks": int(argsDict["ksb"] if "ksb" in argsDict else argsDict["ks"]),
        }
    elif data_gen_type == gtu.DataGenType.SPECIAL:
        mode = gtu.ComplianceMode.FP_SPECIAL
    elif "ulp" in op_compliance:
        mode = gtu.ComplianceMode.ULP
        compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
    elif "relative" in op_compliance:
        mode = gtu.ComplianceMode.RELATIVE
        compliance_tens["relative_info"] = {
            "max": argsDict["max_abs_value"],
            "scale": op["compliance"]["relative"],
        }
    elif op["op"] == Op.REDUCE_PRODUCT:
        mode = gtu.ComplianceMode.REDUCE_PRODUCT
        compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
    elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_lower_bound" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "lower_bound": op["compliance"]["abs_error_lower_bound"]
            }
    elif op["op"] in (Op.SIN, Op.COS):
        mode = gtu.ComplianceMode.ABS_ERROR
        if "abs_error_normal_divisor" in op_compliance:
            compliance_tens["abs_error_info"] = {
                "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
            }
    else:
        mode = gtu.ComplianceMode.EXACT

    compliance_tens["mode"] = gtu.ComplianceMode(mode).name
    return compliance_tens
299
300 # Build Op functions
301 # Create the output tensor (calling OutputShaper as needed)
302 # Do final tweaks to attributes (if necessary for errorIf)
303 # Add Op into graph
304 # Return resulting tensor information or BuildInfo
305
class BuildInfo:
    """Result tensor(s) of a build function with their compliance dicts."""

    def __init__(self, resultTensor, complianceDict):
        """Store the results, wrapping single values in one-element lists.

        Args:
            resultTensor: one result tensor or a list of them.
            complianceDict: ``None``, or compliance data given in the same
                form (single item or list) as ``resultTensor``.
        """
        if not isinstance(resultTensor, list):
            self.resultTensorList = [resultTensor]
            self.complianceDictList = (
                None if complianceDict is None else [complianceDict]
            )
            return

        assert complianceDict is None or isinstance(complianceDict, list)
        self.resultTensorList = resultTensor
        self.complianceDictList = complianceDict

    def getComplianceInfo(self):
        """Return the compliance info keyed by tensor name, or ``None``.

        ``None`` is returned when no compliance data was supplied or when
        every per-tensor entry is ``None``.
        """
        if self.complianceDictList is None:
            return None

        per_tensor = {
            tensor.name: info
            for tensor, info in zip(self.resultTensorList, self.complianceDictList)
            if info is not None
        }
        if not per_tensor:
            return None

        # Have some compliance data, so return the info
        return {
            "version": "0.1",
            "tensors": per_tensor,
        }
Eric Kunzee5e26762020-10-13 16:11:07 -0700339
def build_unary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a single-input operator to the graph.

    Args:
        rng: random generator passed on to the shaper and ERROR_IF helpers.
        op: operator definition dict (``"op"`` and ``"operands"`` are read).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validators given to the error validator.
        error_name: ERROR_IF being tested, or ``None``.
        qinfo: zero points; replaced for WrongOutputType tests whose output
            type is neither INT8 nor UINT8, and read for NEGATE.

    Returns:
        ``None`` if ERROR_IF validation fails, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 1
    source = inputs[0]
    result_tensor = OutputShaper.unaryOp(self.ser, rng, source, error_name)

    assert not isinstance(op, int)

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongOutputType and result_tensor.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        input_zp = TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, source.dtype)
        output_zp = TosaQuantGen.getZeroPoint(
            rng, self.args.zeropoint, result_tensor.dtype
        )
        qinfo = [input_zp, output_zp]

    # Invalidate Input/Output list for error if checks.
    input_list = [source.name]
    output_list = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    validation_args = dict(
        op=op,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    # Only NEGATE carries an attribute (built from the zero points)
    if op["op"] == Op.NEGATE:
        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(qinfo[0], qinfo[1])
    else:
        attr = None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700401
def build_binary_broadcast(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Add a two-input broadcasting operator to the graph.

    Args:
        rng: random generator passed on to the shaper and ERROR_IF helpers.
        op: operator definition dict (``"op"`` and ``"operands"`` are read).
        inputs: list holding exactly two input tensors.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validators given to the error validator.
        error_name: ERROR_IF being tested, or ``None``.
        qinfo: not used by this builder.

    Returns:
        ``None`` if ERROR_IF validation fails, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [lhs.name, rhs.name]
    output_list = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    validation_args = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700441
def build_arithmetic_right_shift(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add an arithmetic right shift operator to the graph.

    Args:
        rng: random generator passed on to the shaper and ERROR_IF helpers.
        op: operator definition dict (``"op"`` and ``"operands"`` are read).
        inputs: list holding exactly two input tensors.
        args_dict: test arguments; ``"round"`` supplies the value for the
            operator attribute.
        validator_fcns: ERROR_IF validators given to the error validator.
        error_name: ERROR_IF being tested, or ``None``.
        qinfo: not used by this builder.

    Returns:
        ``None`` if ERROR_IF validation fails, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs
    rounding = args_dict["round"]
    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, rng, lhs, rhs, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [lhs.name, rhs.name]
    output_list = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    validation_args = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.ArithmeticRightShiftAttribute(rounding)

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Chengaee1fac2020-11-11 13:54:06 -0800492
def build_mul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a multiply operator to the graph.

    Args:
        rng: random generator passed on to the shaper and ERROR_IF helpers.
        op: operator definition dict (``"op"`` and ``"operands"`` are read).
        inputs: list of three tensors: the two operands and the shift
            value tensor.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validators given to the error validator.
        error_name: ERROR_IF being tested, or ``None``.
        qinfo: not used by this builder.

    Returns:
        ``None`` if ERROR_IF validation fails, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    # Note that mul is binary operator but it has a shift value tensor
    assert len(inputs) == 3
    lhs, rhs, shift = inputs

    result_tensor = OutputShaper.binaryBroadcastOp(
        self.ser, rng, lhs, rhs, error_name
    )

    # Special for multiply: Force the result to INT32 for INT types
    if lhs.dtype not in (DType.FP16, DType.BF16, DType.FP32):
        result_tensor.setDtype(DType.INT32)

    if error_name == ErrorIf.WrongOutputType:
        # Pick one of these output types for the WrongOutputType test
        wrong_dtypes = [DType.INT8, DType.INT16, DType.INT48]
        result_tensor.setDtype(rng.choice(wrong_dtypes))

    # Invalidate Input/Output list for error if checks.
    input_list = [lhs.name, rhs.name, shift.name]
    output_list = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    validation_args = dict(
        op=op,
        input1=lhs,
        input2=rhs,
        input_dtype=lhs.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700550
def build_table(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a table lookup operator to the graph.

    Args:
        rng: random generator passed on to the shaper and ERROR_IF helpers.
        op: operator definition dict (``"op"`` and ``"operands"`` are read).
        inputs: list holding exactly one input tensor.
        args_dict: test arguments; ``"table"`` supplies the value for the
            operator attribute.
        validator_fcns: ERROR_IF validators given to the error validator.
        error_name: ERROR_IF being tested, or ``None``.
        qinfo: not used by this builder.

    Returns:
        ``None`` if ERROR_IF validation fails, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 1
    source = inputs[0]
    table_values = args_dict["table"]
    result_tensor = OutputShaper.tableOp(self.ser, rng, source, error_name)

    attr = ts.TosaSerializerAttribute()
    attr.TableAttribute(table_values)

    # Invalidate Input/Output list for error if checks.
    input_list = [source.name]
    output_list = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    validation_args = dict(
        op=op,
        input_shape=source.shape,
        input_dtype=source.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, source.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700600
def build_select(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a select operator to the graph.

    Args:
        rng: random generator passed on to the shaper and ERROR_IF helpers.
        op: operator definition dict (``"op"`` and ``"operands"`` are read).
        inputs: list of three tensors: the condition followed by the two
            value tensors.
        args_dict: test arguments used for the compliance meta data.
        validator_fcns: ERROR_IF validators given to the error validator.
        error_name: ERROR_IF being tested, or ``None``.
        qinfo: not used by this builder.

    Returns:
        ``None`` if ERROR_IF validation fails, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 3
    cond, first, second = inputs

    result_tensor = OutputShaper.selectOp(
        self.ser, rng, cond, first, second, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [cond.name, first.name, second.name]
    output_list = [result_tensor.name]
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    # Shape and dtype checks use the first value tensor, not the condition
    validation_args = dict(
        op=op,
        input1=cond,
        input2=first,
        input3=second,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, first.dtype, args_dict, result_tensor, error_name
    )
    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700653
def build_comparison(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one two-input comparison operator test.

    The output tensor is created by OutputShaper.binaryComparisonOp and the
    operator is only added to the serializer once the ERROR_IF validators
    accept the test.

    Returns a TosaTestGen.BuildInfo with the result tensor and its compliance
    meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    lhs, rhs = inputs

    out_tensor = OutputShaper.binaryComparisonOp(self.ser, rng, lhs, rhs, error_name)

    # Tensor name lists - these may be invalidated for ERROR_IF checks.
    in_names = [lhs.name, rhs.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input1": lhs,
        "input2": rhs,
        "input_shape": lhs.shape,
        "input_dtype": lhs.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    compliance = self.tensorComplianceMetaData(
        op, lhs.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700706
def build_argmax(
    self, rng, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
):
    """Serialize one single-input test whose axis comes from args_dict["axis"].

    The output tensor is created by OutputShaper.argmaxOp and the axis is
    stored on the operator through an AxisAttribute.

    Returns a TosaTestGen.BuildInfo with the result tensor and its compliance
    meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.argmaxOp(self.ser, rng, operand, axis, error_name)

    # Tensor name lists - these may be invalidated for ERROR_IF checks.
    in_names = [operand.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "axis": axis,
        "input_shape": operand.shape,
        "input_dtype": operand.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, out_tensor, error_name
    )
    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700750
def build_pool2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one 2D pooling operator test.

    Reads "stride", "pad", "kernel" and (optionally) "acc_type" from
    args_dict, creates the output tensor with OutputShaper.pool2dOp and
    stores the pooling parameters on the operator via a PoolAttribute.

    Returns a TosaTestGen.BuildInfo with the result tensor and its compliance
    meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    ifm = inputs[0]
    # max_pool has no accum_dtype
    accum_dtype = args_dict.get("acc_type", DType.UNKNOWN)
    stride = args_dict["stride"]
    pad = args_dict["pad"]
    kernel = args_dict["kernel"]

    out_tensor = OutputShaper.pool2dOp(
        self.ser, rng, ifm, kernel, stride, pad, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in [
        DType.INT8,
        DType.UINT8,
    ]:
        input_zp = TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype)
        output_zp = TosaQuantGen.getZeroPoint(
            rng, self.args.zeropoint, out_tensor.dtype
        )
        qinfo = [input_zp, output_zp]

    # Tensor name lists - these may be invalidated for ERROR_IF checks.
    in_names = [ifm.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": ifm.shape,
        "input_dtype": ifm.dtype,
        "output_shape": out_tensor.shape,
        "output_dtype": out_tensor.dtype,
        "accum_dtype": accum_dtype,
        "kernel": kernel,
        "stride": stride,
        "pad": pad,
        "qinfo": qinfo,
        "result_tensors": [out_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # The validators above see the qinfo as given; only the serialized
    # attribute falls back to zero for both zero points.
    if qinfo is None:
        qinfo = [0, 0]

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
    self.ser.addOperator(op["op"], input_list, output_list, pool_attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, out_tensor, error_name
    )

    return TosaTestGen.BuildInfo(out_tensor, compliance)
James Ward8b390432022-08-12 20:48:56 +0100828
def build_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one 2D convolution operator test.

    Args:
        rng: random generator handed to the output shaper, zero point
            generation and the input/output list invalidation.
        op: operator description; "op" and "operands" are read here.
        inputs: exactly three tensors - ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" (4 values)
            and "dilation".
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF being tested, or None.
        qinfo: pair of zero points stored in the ConvAttribute. When None,
            both zero points default to 0.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 4
    result_tensor = OutputShaper.conv2dOp(
        self.ser,
        rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be indexed below; fall back to
    # zero for both zero points, as build_pool2d does. The validators above
    # still see the qinfo exactly as it was given.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -0700916
def build_conv3d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one 3D convolution operator test.

    Args:
        rng: random generator handed to the output shaper, zero point
            generation and the input/output list invalidation.
        op: operator description; "op" and "operands" are read here.
        inputs: exactly three tensors - ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" (6 values)
            and "dilation".
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF being tested, or None.
        qinfo: pair of zero points stored in the ConvAttribute. When None,
            both zero points default to 0.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    assert len(padding) == 6
    result_tensor = OutputShaper.conv3dOp(
        self.ser,
        rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be indexed below; fall back to
    # zero for both zero points, as build_pool2d does. The validators above
    # still see the qinfo exactly as it was given.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Kevin Cheng1533b852021-09-01 12:51:58 -07001004
def build_transpose_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one 2D transpose convolution operator test.

    Args:
        rng: random generator handed to the output shaper, zero point
            generation and the input/output list invalidation.
        op: operator description; "op" and "operands" are read here.
        inputs: exactly three tensors - ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" (4 values)
            and "out_shape".
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF being tested, or None.
        qinfo: pair of zero points stored in the TransposeConvAttribute.
            When None, both zero points default to 0.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    out_pad = args_dict["pad"]
    output_shape = args_dict["out_shape"]

    assert len(out_pad) == 4
    result_tensor = OutputShaper.transposeConv2DOp(
        self.ser, rng, ifm, output_shape, accum_dtype, error_name
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    # Note: the validators receive the shape of the tensor that was actually
    # created, while the attribute below stores the requested "out_shape".
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=out_pad,
        stride=strides,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be indexed below; fall back to
    # zero for both zero points, as build_pool2d does. The validators above
    # still see the qinfo exactly as it was given.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.TransposeConvAttribute(
        out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001083
def build_depthwise_conv2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one depthwise 2D convolution operator test.

    Args:
        rng: random generator handed to the output shaper, zero point
            generation and the input/output list invalidation.
        op: operator description; "op" and "operands" are read here.
        inputs: exactly three tensors - ifm, filter and bias.
        args_dict: must provide "acc_type", "stride", "pad" and "dilation".
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF being tested, or None.
        qinfo: pair of zero points stored in the ConvAttribute. When None,
            both zero points default to 0.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]
    strides = args_dict["stride"]
    padding = args_dict["pad"]
    dilations = args_dict["dilation"]

    result_tensor = OutputShaper.depthwiseConv2dOp(
        self.ser,
        rng,
        ifm,
        filter,
        accum_dtype,
        strides,
        padding,
        dilations,
        error_name,
    )

    # Ensure new output type has correct qinfo
    if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
        DType.INT8,
        DType.UINT8,
    ):
        qinfo = [
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype),
            TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, result_tensor.dtype),
        ]

    # Invalidate Input/Output list for error_if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    num_operands = sum(op["operands"])
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        input_list=input_list,
        num_operands=num_operands,
        output_list=output_list,
        pad=padding,
        stride=strides,
        dilation=dilations,
        input_shape=ifm.shape,
        weight_shape=filter.shape,
        output_shape=result_tensor.shape,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be indexed below; fall back to
    # zero for both zero points, as build_pool2d does. The validators above
    # still see the qinfo exactly as it was given.
    if qinfo is None:
        qinfo = [0, 0]

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute(
        padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001170
def build_fully_connected(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one fully connected operator test.

    Args:
        rng: random generator handed to the output shaper and the
            input/output list invalidation.
        op: operator description; "op" and "operands" are read here.
        inputs: exactly three tensors - ifm, filter and bias.
        args_dict: must provide "acc_type".
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF being tested, or None.
        qinfo: pair of zero points stored in the FullyConnectedAttribute.
            When None, both zero points default to 0.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 3
    ifm, filter, bias = inputs
    accum_dtype = args_dict["acc_type"]

    result_tensor = OutputShaper.fullyConnectedOp(
        self.ser, rng, ifm, filter, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [ifm.name, filter.name, bias.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=ifm.shape,
        input_dtype=ifm.dtype,
        weight_dtype=filter.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be indexed below; fall back to
    # zero for both zero points, as build_pool2d does. The validators above
    # still see the qinfo exactly as it was given.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.FullyConnectedAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, ifm.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001227
def build_matmul(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize one matrix multiplication operator test.

    Args:
        rng: random generator handed to the output shaper and the
            input/output list invalidation.
        op: operator description; "op" and "operands" are read here.
        inputs: exactly two tensors - a and b.
        args_dict: must provide "acc_type".
        validator_fcns: ERROR_IF validators passed to evValidateErrorIfs.
        error_name: ERROR_IF being tested, or None.
        qinfo: pair of zero points stored in the MatMulAttribute. When
            None, both zero points default to 0.

    Returns:
        TosaTestGen.BuildInfo with the result tensor and its compliance
        meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 2
    a, b = inputs
    accum_dtype = args_dict["acc_type"]
    result_tensor = OutputShaper.matmulOp(
        self.ser, rng, a, b, accum_dtype, error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=a.shape,
        input_dtype=a.dtype,
        input2_shape=b.shape,
        input2_dtype=b.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        qinfo=qinfo,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
        accum_dtype=accum_dtype,
    ):
        return None

    # qinfo defaults to None, which cannot be indexed below; fall back to
    # zero for both zero points, as build_pool2d does. The validators above
    # still see the qinfo exactly as it was given.
    if qinfo is None:
        qinfo = [0, 0]

    attr = ts.TosaSerializerAttribute()
    attr.MatMulAttribute(qinfo[0], qinfo[1])

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001284
def build_reduce(
    self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
):
    """Serialize one single-input reduction operator test.

    The reduction axis is read from args_dict["axis"], the output tensor is
    created by OutputShaper.reduceOp and the axis is stored on the operator
    through an AxisAttribute.

    Note that for a non-error REDUCE_PRODUCT test this writes the entry "n"
    into the caller's args_dict before the compliance meta data is built.

    Returns a TosaTestGen.BuildInfo with the result tensor and its compliance
    meta data, or None when evValidateErrorIfs rejects the test.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    axis = args_dict["axis"]
    out_tensor = OutputShaper.reduceOp(self.ser, rng, operand, axis, error_name)

    # Tensor name lists - these may be invalidated for ERROR_IF checks.
    in_names = [operand.name]
    out_names = [out_tensor.name]
    p_count, c_count = op["operands"]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "axis": axis,
        "input_shape": operand.shape,
        "output_shape": out_tensor.shape,
        "input_dtype": operand.dtype,
        "output_dtype": out_tensor.dtype,
        "result_tensors": [out_tensor],
        "input_list": input_list,
        "output_list": output_list,
        "num_operands": p_count + c_count,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)
    self.ser.addOperator(op["op"], input_list, output_list, axis_attr)

    is_valid_product = error_name is None and op["op"] == Op.REDUCE_PRODUCT
    if is_valid_product:
        # Number of products - needed for compliance
        args_dict["n"] = operand.shape[axis]

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, out_tensor, error_name
    )

    return TosaTestGen.BuildInfo(out_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001333
def build_clamp(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a clamp operator with randomly drawn limits.

    Two numbers are drawn with rng.randNumberDType for the input dtype.
    Normally the smaller becomes min_val and the larger max_val; for
    ErrorIf.MaxSmallerMin the pair is redrawn until the numbers differ and
    the roles are swapped on purpose.  The limits are packed to bytes with
    struct and handed to ClampAttribute.  qinfo is accepted but unused.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 1
    operand = inputs[0]

    result = OutputShaper.unaryOp(self.ser, rng, operand, error_name)

    first = rng.randNumberDType(operand.dtype)
    second = rng.randNumberDType(operand.dtype)

    if error_name == ErrorIf.MaxSmallerMin:
        # The error only triggers when the two limits are not equal
        while first == second:
            first = rng.randNumberDType(operand.dtype)
            second = rng.randNumberDType(operand.dtype)
        max_val, min_val = min(first, second), max(first, second)
    else:
        max_val, min_val = max(first, second), min(first, second)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "max_val": max_val,
        "min_val": min_val,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    # Choose the struct format and the values to pack for this dtype
    if operand.dtype in (DType.BF16, DType.FP16, DType.FP32):
        if operand.dtype == DType.FP16:
            # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
            min_val = min_val.astype(np.float32)
            max_val = max_val.astype(np.float32)
        pack_fmt, low, high = "<f", min_val, max_val
    elif operand.dtype in (DType.INT8, DType.INT16):
        pack_fmt, low, high = "<i", min_val, max_val
    else:
        # to avoid internal error for incorrect input types
        pack_fmt, low, high = "<i", 0, 0

    min_val_as_bytes = struct.pack(pack_fmt, low)
    max_val_as_bytes = struct.pack(pack_fmt, high)

    clamp_attr = ts.TosaSerializerAttribute()
    clamp_attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)

    self.ser.addOperator(op["op"], in_names, out_names, clamp_attr)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001413
def build_activation(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a one-input operator that carries no attribute.

    The result tensor comes from OutputShaper.unaryOp.  qinfo is accepted
    but not used by this builder.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 1
    operand = inputs[0]

    result = OutputShaper.unaryOp(self.ser, rng, operand, error_name)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Won Jeon78155c62023-06-10 00:20:04 +00001461
def build_concat(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a CONCAT or CONCAT_SHAPE operator over all input tensors.

    The axis is 0 for Op.CONCAT_SHAPE and args_dict["axis"] otherwise.
    Unless the test targets ErrorIf.WrongInputType, the axis must be an
    int.  Only Op.CONCAT gets an axis attribute; Op.CONCAT_SHAPE is
    serialized without one.  qinfo is accepted but not used here.

    Returns TosaTestGen.BuildInfo(result_tensor, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    if op["op"] == Op.CONCAT_SHAPE:
        axis = 0
    else:
        axis = args_dict["axis"]
    if error_name != ErrorIf.WrongInputType:
        assert isinstance(axis, int)

    result_tensor = OutputShaper.concatOp(
        self.ser, rng, axis, inputs, error_name=error_name
    )

    # Invalidate Input/Output list for error if checks.
    input_list = [tensor.name for tensor in inputs]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        axis=axis,
        # Shape/dtype checks use the first input as the reference
        input_shape=inputs[0].shape,
        output_shape=result_tensor.shape,
        input_dtype=inputs[0].dtype,
        output_dtype=result_tensor.dtype,
        inputs=inputs,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    if op["op"] == Op.CONCAT:
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)
    else:
        assert op["op"] == Op.CONCAT_SHAPE
        attr = None
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, inputs[0].dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001527
def build_pad(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a pad operator with a packed pad-constant attribute.

    inputs holds the tensor to pad and a second tensor that is passed on
    as an operator input.  args_dict["pad"] is given to OutputShaper.padOp
    and to the validator; the pad constant is taken from
    args_dict["pad_const_fp"] when gtu.dtypeIsFloat is true for the input
    dtype and from args_dict["pad_const_int"] otherwise.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 2
    operand = inputs[0]
    pad_tensor = inputs[1]
    padding = args_dict["pad"]
    int_const = args_dict["pad_const_int"]
    fp_const = args_dict["pad_const_fp"]

    result = OutputShaper.padOp(self.ser, rng, operand, padding, error_name)

    # The pad constant is serialized as 4 little-endian bytes
    if gtu.dtypeIsFloat(operand.dtype):
        pack_fmt, pad_const = "<f", fp_const
    else:
        pack_fmt, pad_const = "<i", int_const
    pad_const_val_as_bytes = struct.pack(pack_fmt, pad_const)

    pad_attr = ts.TosaSerializerAttribute()
    pad_attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name, pad_tensor.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "pad": padding,
        "qinfo": qinfo,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
        "input1": operand,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, pad_attr)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001591
def build_dim(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a dim operator for the axis in args_dict["axis"].

    The result tensor comes from OutputShaper.dimOp and the operator is
    added with an axis attribute.  qinfo is accepted but not used here.

    Returns TosaTestGen.BuildInfo(result, None) - no compliance meta data
    is built - or None when TosaErrorValidator.evValidateErrorIfs returns
    a false value.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    axis = args_dict["axis"]
    result = OutputShaper.dimOp(self.ser, rng, operand, axis, error_name)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "axis": axis,
        "input_shape": operand.shape,
        "input_dtype": operand.dtype,
        "output_shape": result.shape,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(result, None)
Won Jeona21b2e82023-08-10 10:33:01 +00001638
def build_reshape(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a reshape operator taking a data and a shape operand.

    Both tensors in inputs are passed on as operator inputs, while the
    result tensor is derived by OutputShaper.reshapeOp from
    args_dict["new_shape"].  No attribute is attached and qinfo is not
    used.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 2
    operand = inputs[0]
    shape_tensor = inputs[1]
    new_shape = args_dict["new_shape"]
    result = OutputShaper.reshapeOp(self.ser, rng, operand, new_shape, error_name)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name, shape_tensor.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001687
def build_reverse(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a reverse operator for the axis in args_dict["axis"].

    The result tensor comes from OutputShaper.unaryOp and the operator is
    added with an axis attribute.  qinfo is accepted but not used here.

    Returns TosaTestGen.BuildInfo(result, None) - no compliance meta data
    is built - or None when TosaErrorValidator.evValidateErrorIfs returns
    a false value.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    axis = args_dict["axis"]
    result = OutputShaper.unaryOp(self.ser, rng, operand, error_name)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "axis": axis,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op["op"], in_names, out_names, axis_attr)
    return TosaTestGen.BuildInfo(result, None)
Eric Kunzee5e26762020-10-13 16:11:07 -07001734
def build_transpose(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a transpose operator using args_dict["perms"].

    perms is used by OutputShaper.transposeOp to create the result
    tensor, stored in a transpose attribute and forwarded to the
    validator.  qinfo is accepted but not used here.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 1
    operand = inputs[0]
    perms = args_dict["perms"]

    result = OutputShaper.transposeOp(self.ser, rng, operand, perms, error_name)

    transpose_attr = ts.TosaSerializerAttribute()
    transpose_attr.TransposeAttribute(perms)

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "perms": perms,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
        "input1": operand,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names, transpose_attr)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001788
def build_slice(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a slice operator with data, start and size operands.

    All three tensors in inputs are passed on as operator inputs.  The
    values in args_dict["start"] and args_dict["size"] are what
    OutputShaper.sliceOp and the validator receive.  No attribute is
    attached and qinfo is not used.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 3
    operand, start_tensor, size_tensor = inputs
    start_values = args_dict["start"]
    size_values = args_dict["size"]

    result = OutputShaper.sliceOp(
        self.ser, rng, operand, start_values, size_values, error_name
    )

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name, start_tensor.name, size_tensor.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "start": start_values,
        "size": size_values,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
        "input1": operand,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001843
def build_tile(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a tile operator taking a data and a multiples operand.

    Both tensors in inputs are passed on as operator inputs, while the
    result tensor is derived by OutputShaper.tileOp from
    args_dict["multiples"].  No attribute is attached and qinfo is not
    used.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 2
    operand = inputs[0]
    multiples_tensor = inputs[1]
    multiples_values = args_dict["multiples"]
    result = OutputShaper.tileOp(
        self.ser, rng, operand, multiples_values, error_name
    )

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [operand.name, multiples_tensor.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": operand.shape,
        "output_shape": result.shape,
        "input_dtype": operand.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
        "input1": operand,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, operand.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07001895
def build_gather(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a gather operator over a values and an indices tensor.

    The result tensor comes from OutputShaper.gatherOp; validation and
    compliance use the shape and dtype of the values tensor.  No attribute
    is attached and qinfo is not used.

    Returns TosaTestGen.BuildInfo(result, compliance), or None when
    TosaErrorValidator.evValidateErrorIfs returns a false value.
    """
    assert len(inputs) == 2
    values_tensor, indices_tensor = inputs

    result = OutputShaper.gatherOp(
        self.ser, rng, values_tensor, indices_tensor, error_name
    )

    # Operand names are passed through the error-if invalidation step
    # before they are validated and serialized.
    in_names = [values_tensor.name, indices_tensor.name]
    out_names = [result.name]
    p_count, c_count = op["operands"]
    num_operands = p_count + c_count
    in_names, out_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, in_names, out_names
    )

    checks = {
        "op": op,
        "input_shape": values_tensor.shape,
        "output_shape": result.shape,
        "input_dtype": values_tensor.dtype,
        "output_dtype": result.dtype,
        "result_tensors": [result],
        "input_list": in_names,
        "output_list": out_names,
        "num_operands": num_operands,
    }
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **checks
    ):
        return None

    self.ser.addOperator(op["op"], in_names, out_names)

    compliance = self.tensorComplianceMetaData(
        op, values_tensor.dtype, args_dict, result, error_name
    )
    return TosaTestGen.BuildInfo(result, compliance)
Kevin Cheng77d0f762020-11-24 10:26:32 -08001945
def build_scatter(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a scatter operator to the serializer and describe its result.

    ``inputs`` must hold exactly three tensors: ``values_in``, ``indices``
    and ``input`` (in that order).  The output tensor is created by
    ``OutputShaper.scatterOp``.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns
    a falsy value, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.  ``qinfo`` is not used here.
    """
    assert len(inputs) == 3
    values_in, indices, input_tens = inputs

    result_tensor = OutputShaper.scatterOp(
        self.ser, rng, values_in, indices, input_tens, error_name
    )

    # Operand count is taken from the op definition (placeholders + consts).
    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [tensor.name for tensor in (values_in, indices, input_tens)],
        [result_tensor.name],
    )

    validation_args = dict(
        op=op,
        input_shape=values_in.shape,
        output_shape=result_tensor.shape,
        input_dtype=values_in.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, values_in.dtype, args_dict, result_tensor, error_name
        ),
    )
Kevin Cheng77d0f762020-11-24 10:26:32 -08001994
def build_resize(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns,
    error_name=None,
    qinfo=None,
):
    """Add a resize operator to the serializer and describe its result.

    ``inputs`` must hold exactly four tensors: the data input followed by
    the scale, offset and border tensors.  ``args_dict`` supplies the
    ``mode``, ``scale``, ``offset``, ``border`` and ``output_dtype`` values
    used for output shaping and ERROR_IF validation.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns
    a falsy value, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.  ``qinfo`` is not used here.
    """
    assert len(inputs) == 4
    input_tens, scale_tens, offset_tens, border_tens = inputs

    mode = args_dict["mode"]
    scale = args_dict["scale"]
    offset = args_dict["offset"]
    border = args_dict["border"]
    output_dtype = args_dict["output_dtype"]

    result_tensor = OutputShaper.resizeOp(
        self.ser,
        rng,
        input_tens,
        mode,
        scale,
        offset,
        border,
        input_tens.dtype,
        output_dtype,
        error_name,
    )

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    operand_names = [
        tensor.name
        for tensor in (input_tens, scale_tens, offset_tens, border_tens)
    ]
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, operand_names, [result_tensor.name]
    )

    validation_args = dict(
        op=op,
        mode=mode,
        scale=scale,
        input_dtype=input_tens.dtype,
        output_dtype=output_dtype,
        input_shape=input_tens.shape,
        output_shape=result_tensor.shape,
        offset=offset,
        border=border,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    # write empty scale/offset/border into ResizeAttribute
    # (the tensors in input_list carry those operands instead)
    attr = ts.TosaSerializerAttribute()
    attr.ResizeAttribute([], [], [], mode)
    self.ser.addOperator(op["op"], input_list, output_list, attr)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, input_tens.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002074
def build_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Register the single input tensor as a graph output.

    No operator is added: the one tensor in ``inputs`` is passed to
    ``self.ser.addOutputTensor`` and is also used as the result tensor in
    the returned ``TosaTestGen.BuildInfo``.  ``rng``, ``validator_fcns``
    and ``qinfo`` are not used here.
    """
    assert len(inputs) == 1
    (const_tens,) = inputs
    self.ser.addOutputTensor(const_tens)

    return TosaTestGen.BuildInfo(
        const_tens,
        self.tensorComplianceMetaData(
            op, const_tens.dtype, args_dict, const_tens, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002094
2095 # Type Conversion
def build_cast(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a cast operator to the serializer and describe its result.

    ``inputs`` must hold exactly one tensor; the requested output type is
    read from ``args_dict["out_type"]`` and the output tensor is created by
    ``OutputShaper.typeConversionOp``.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns
    a falsy value, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.  ``qinfo`` is not used here.
    """
    assert len(inputs) == 1
    (val,) = inputs
    out_dtype = args_dict["out_type"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    placeholder_count, const_count = op["operands"]
    num_operands = placeholder_count + const_count

    # Invalidate Input/Output list for error if checks.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [val.name], [result_tensor.name]
    )

    validation_args = dict(
        op=op,
        input_shape=val.shape,
        output_shape=result_tensor.shape,
        input_dtype=val.dtype,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    )
    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser, validator_fcns, error_name, **validation_args
    )
    if not is_valid:
        return None

    self.ser.addOperator(op["op"], input_list, output_list)

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, val.dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002146
def build_rescale(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a rescale operator to the serializer and describe its result.

    ``inputs`` must hold exactly three tensors: the data input followed by
    the multiplier and shift tensors.  ``args_dict`` supplies
    ``output_dtype``, ``scale`` (the scale32 flag), ``double_round``,
    ``per_channel``, ``shift`` and ``multiplier``.

    Input/output zero points are chosen from ``rng`` according to the
    tensor dtypes and ``error_name``.  For scale32 tests without an
    ERROR_IF, the placeholder data file of the input is reloaded, clamped
    so that ``value - input_zp`` fits the per-channel shift range, and
    rewritten if any value changed.

    The incoming ``qinfo`` argument is ignored; it is rebuilt from the
    chosen zero points before validation.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns
    a falsy value, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.
    """
    assert len(inputs) == 3
    val = inputs[0]
    multiplier_val = inputs[1]
    shift_val = inputs[2]
    out_dtype = args_dict["output_dtype"]
    scale32 = args_dict["scale"]
    double_round = args_dict["double_round"]
    per_channel = args_dict["per_channel"]
    shift_arr = args_dict["shift"]
    multiplier_arr = args_dict["multiplier"]

    result_tensor = OutputShaper.typeConversionOp(
        self.ser, rng, val, out_dtype, error_name
    )

    # One shift/multiplier per channel (last dimension) or a single one.
    nc = val.shape[-1] if per_channel else 1

    def pick_zero_point(dtype, zero_point_errors):
        """Return (zero_point, is_unsigned) for one side of the rescale."""
        if dtype == DType.INT8:
            return rng.randInt(-128, 128), False
        if dtype == DType.UINT8:
            return rng.randInt(0, 256), True
        if error_name in zero_point_errors:
            # Deliberately produce a non-zero zero point for the ERROR_IF
            zero_point = rng.randInt(-128, 128)
            if zero_point == 0:
                zero_point = zero_point + rng.integers(1, 10)
            return zero_point, False
        if dtype == DType.UINT16:
            # Must come after the U16 zero point ERROR_IF check above
            return rng.choice([0, 32768]), True
        return 0, False

    # Input side is resolved before the output side to keep the order of
    # random number draws stable.
    input_zp, input_unsigned = pick_zero_point(
        val.dtype,
        (ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid),
    )
    output_zp, output_unsigned = pick_zero_point(
        out_dtype,
        (ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid),
    )

    min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
    max_shift_value_arr = np.int64(np.zeros(shape=[nc]))

    for i in range(nc):
        # int() keeps the shift arithmetic in Python integers so that a
        # fixed-width NumPy scalar in shift_arr cannot overflow for large
        # shift values.
        shift = int(shift_arr[i])
        min_shift_value_arr[i] = -1 << (shift - 1)
        max_shift_value_arr[i] = (1 << (shift - 1)) - 1

    logger.debug(
        f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
    )
    if scale32 and error_name is None:
        # Make sure random values are within apply_scale_32 specification
        # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
        assert val.placeholderFilename
        placeholder_path = os.path.join(
            self.basePath, self.testPath, val.placeholderFilename
        )
        values = np.load(placeholder_path)
        val_adj = np.subtract(values, input_zp, dtype=np.int64)
        val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64)
        val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64)
        val_adj = np.add(val_adj, input_zp, dtype=np.int64)

        # Check we can safely convert to the expected dtype.  The
        # comparison must be element-wise: reducing with .all() first
        # would only compare a single boolean against the limits.
        dtype_limits = np.iinfo(values.dtype)
        assert np.all(val_adj >= dtype_limits.min) and np.all(
            val_adj <= dtype_limits.max
        )

        # Force casting to output datatype
        val_adj = val_adj.astype(values.dtype, casting="unsafe")

        if not np.array_equal(values, val_adj):
            # Values changed so overwrite file with new values
            np.save(placeholder_path, val_adj, allow_pickle=False)

    # Invalidate Input/Output list for error if checks.
    input_list = [val.name, multiplier_val.name, shift_val.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    qinfo = (input_zp, output_zp)
    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_dtype=val.dtype,
        output_dtype=out_dtype,
        input_shape=val.shape,
        qinfo=qinfo,
        scale32=scale32,
        double_round=double_round,
        input_list=input_list,
        output_list=output_list,
        result_tensors=[result_tensor],
        num_operands=num_operands,
    ):
        return None

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        scale32,
        double_round,
        per_channel,
        input_unsigned,
        output_unsigned,
    )

    self.ser.addOperator(op["op"], input_list, output_list, attr)

    compliance = self.tensorComplianceMetaData(
        op, val.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002313
def _get_condition_tensor(self, rng, op, cond, error_name):
    """Create the constant condition tensor for a cond_if test.

    The tensor is normally a rank 0 ``DType.BOOL`` constant holding
    ``cond``.  ``ErrorIf.CondIfCondNotMatchingBool`` swaps the type for one
    returned by ``gtu.get_wrong_output_type`` and
    ``ErrorIf.CondIfCondShapeNotSizeOne`` swaps the shape for ``[2]`` or
    ``[1, 2]`` (chosen by ``rng``).
    """
    cond_type = (
        gtu.get_wrong_output_type(op, rng, DType.BOOL)
        if error_name == ErrorIf.CondIfCondNotMatchingBool
        else DType.BOOL
    )

    if error_name != ErrorIf.CondIfCondShapeNotSizeOne:
        # Must be of size 1 (rank 0)
        cond_shape = []
    elif rng.choice([1, 2]) == 1:
        cond_shape = [2]
    else:
        cond_shape = [1, 2]

    return self.ser.addConst(cond_shape, cond_type, [cond])
2330
def build_cond_if_const(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a cond_if operator whose then/else blocks each hold one constant.

    Only the shape of the first of the two supplied tensors is used; the
    block bodies are freshly generated random INT32 constants.  For the
    output-list mismatch ERROR_IFs the affected block gets a constant with
    a deliberately different shape.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns
    a falsy value, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.  ``qinfo`` is not used here.
    """
    assert len(inputs) == 2
    shape_source, _ = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    out_shape = shape_source.shape
    dtype = DType.INT32

    def random_int32(shape):
        return np.int32(rng.integers(0, 256, size=shape))

    then_mismatch = ErrorIf.CondIfOutputListThenGraphMismatch
    else_mismatch = ErrorIf.CondIfOutputListElseGraphMismatch

    # Create an incorrect output shape for error_if tests
    if error_name in [then_mismatch, else_mismatch]:
        incorrect_shape = deepcopy(shape_source.shape)
        for axis, dim in enumerate(incorrect_shape):
            if dim > 3:
                delta = rng.choice([-3, -2, 2, 3])
            else:
                delta = rng.choice([1, 2, 4])
            incorrect_shape[axis] = dim + delta
        incorrect_arr = random_int32(incorrect_shape)

    # Block contents are drawn after any mismatch data, then first
    then_arr = random_int32(out_shape)
    else_arr = random_int32(out_shape)

    # And the result tensor based on any of the outputs
    result_tensor = self.ser.addOutput(out_shape, dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    # Finally, build the op and the two blocks
    self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)

    # Build the actual then/else tensors inside their blocks
    block_specs = (
        (then_block, then_mismatch, then_arr),
        (else_block, else_mismatch, else_arr),
    )
    for block_name, mismatch_error, block_values in block_specs:
        self.ser.addBasicBlock(block_name)
        if error_name == mismatch_error:
            block_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
        else:
            block_tens = self.ser.addConst(out_shape, dtype, block_values)
        self.ser.addOutputTensor(block_tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    )
    if not is_valid:
        return None

    return TosaTestGen.BuildInfo(
        result_tensor,
        self.tensorComplianceMetaData(
            op, dtype, args_dict, result_tensor, error_name
        ),
    )
Eric Kunzee5e26762020-10-13 16:11:07 -07002416
def build_cond_if_binary(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Add a cond_if operator whose then/else blocks hold a binary op.

    For cond_if with a binary op in the then/else blocks, take ``a`` and
    ``b`` (the two tensors in ``inputs``) and apply one of two element-wise
    operators to them depending on the condition:

    * FP32 / BF16 / FP16 / INT32: ``add`` (then) and ``sub`` (else)
    * INT8 / INT16: ``logical_right_shift`` (then) and
      ``logical_left_shift`` (else)

    For the input/output list mismatch ERROR_IFs, the affected block is
    given an input tensor or an output shape that differs from ``a``.

    Returns ``None`` when ``TosaErrorValidator.evValidateErrorIfs`` returns
    a falsy value, otherwise a ``TosaTestGen.BuildInfo`` holding the result
    tensor and its compliance meta data.  ``qinfo`` is not used here.

    Raises:
        AssertionError: if ``a.dtype`` is not one of the types listed above.
    """
    assert len(inputs) == 2
    a, b = inputs

    cond = args_dict["condition"]

    # Condition tensor
    cond_tens = self._get_condition_tensor(rng, op, cond, error_name)

    result_tensor = self.ser.addOutput(a.shape, a.dtype)

    # Create the attribute with the names of the then/else blocks
    then_block = "THEN_BLOCK"
    else_block = "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    if error_name in [
        ErrorIf.CondIfInputListThenGraphMismatch,
        ErrorIf.CondIfInputListElseGraphMismatch,
        ErrorIf.CondIfOutputListElseGraphMismatch,
        ErrorIf.CondIfOutputListThenGraphMismatch,
    ]:
        incorrect_shape = a.shape.copy()
        for i, dim in enumerate(incorrect_shape):
            # Small dimensions are only ever grown so the mismatching
            # shape never ends up with a zero or negative dimension
            # (same rule as build_cond_if_const); one draw per dimension.
            incorrect_shape[i] = dim + (
                rng.choice([-3, -2, 2, 3]) if dim > 3 else rng.choice([1, 2, 4])
            )
        incorrect_block_input = deepcopy(a)
        incorrect_block_input.shape = incorrect_shape

    # Finally, build the op and the two blocks
    self.ser.addOperator(
        op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
    )

    if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
        then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
    elif a.dtype in (DType.INT8, DType.INT16):
        then_op, else_op = (
            self.TOSA_OP_LIST["logical_right_shift"],
            self.TOSA_OP_LIST["logical_left_shift"],
        )
    else:
        # Explicit raise (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f"No tests for DType: {a.dtype}")

    # Determine the element-wise binary operation that compliance will need to
    # check the results of
    compliance_op = then_op if cond else else_op

    # Which block (if any) receives the mismatching input / output
    input_mismatch_block = {
        ErrorIf.CondIfInputListThenGraphMismatch: then_block,
        ErrorIf.CondIfInputListElseGraphMismatch: else_block,
    }.get(error_name)
    output_mismatch_block = {
        ErrorIf.CondIfOutputListThenGraphMismatch: then_block,
        ErrorIf.CondIfOutputListElseGraphMismatch: else_block,
    }.get(error_name)

    for block, block_op in ((then_block, then_op), (else_block, else_op)):
        self.ser.addBasicBlock(block)

        if block == input_mismatch_block:
            first_input = incorrect_block_input
        else:
            first_input = a
        if block == output_mismatch_block:
            block_out_shape = incorrect_block_input.shape
        else:
            block_out_shape = a.shape

        self.ser.addInputTensor(first_input)
        self.ser.addInputTensor(b)
        tens = self.ser.addOutput(block_out_shape, a.dtype)
        self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        a=a,
        b=b,
        basicBlocks=self.ser.currRegion.basicBlocks,
        cond=cond_tens,
    ):
        return None

    compliance = self.tensorComplianceMetaData(
        compliance_op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2521
def build_while_loop(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a WHILE_LOOP test graph for a single input tensor.

    The loop carries three tensors: an INT32 iteration counter, the input
    tensor ``a`` and an accumulator with the shape and dtype of ``a``.  The
    condition block compares the counter with zero (GREATER) and the body
    block adds ``a`` to the accumulator (ADD) and subtracts one from the
    counter (SUB).

    Args:
        rng: random generator; only its ``choice`` method is used here, to
            pick the deliberately wrong shapes/dtypes of ERROR_IF tests.
        op: operator description dictionary (``op["op"]`` is serialized).
        inputs: list holding exactly one tensor.
        args_dict: must contain the ``"iterations"`` entry.
        validator_fcns: ERROR_IF validators handed to the error validator.
        error_name: ERROR_IF case being generated, or None.
        qinfo: accepted for interface compatibility, not used here.

    Returns:
        None when the ERROR_IF validation rejects the test, otherwise a
        ``TosaTestGen.BuildInfo`` for the accumulator output.
    """
    assert len(inputs) == 1
    a = inputs[0]
    iter_val = args_dict["iterations"]

    shape_deltas = [-3, -2, 2, 3]

    def mismatched_copy(tensor):
        # Work on a deep copy so the real tensor keeps its shape; every
        # existing dimension is shifted by a randomly chosen delta.
        wrong = deepcopy(tensor)
        for dim in range(len(wrong.shape)):
            wrong.shape[dim] += rng.choice(shape_deltas)
        return wrong

    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator tensor, initialised to zeros
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Intermediate/output tensors for everything going through the loop
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    if error_name == ErrorIf.InputListOutputListMismatch:
        acc_out_shape = mismatched_copy(acc).shape
    else:
        acc_out_shape = acc.shape
    acc_out = self.ser.addIntermediate(acc_out_shape, acc.dtype)

    # While_loop operator
    self.ser.addOperator(
        op["op"],
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )
    self.ser.addOutputTensor(acc_out)

    graph_mismatch_errors = (
        ErrorIf.InputListCondGraphMismatch,
        ErrorIf.InputListBodyGraphInputMismatch,
        ErrorIf.InputListBodyGraphOutputMismatch,
    )
    if error_name in graph_mismatch_errors:
        wrong_iter = mismatched_copy(iter_tens)
        if len(wrong_iter.shape) == 0:
            # A rank 0 shape has nothing to perturb, so add a dimension
            wrong_iter.shape.append(rng.choice(shape_deltas))
        wrong_acc = mismatched_copy(acc)

    # COND block (input: iter, output: cond_tens )
    self.ser.addBasicBlock(cond_block)

    if error_name == ErrorIf.InputListCondGraphMismatch:
        cond_inputs = (wrong_iter, a, wrong_acc)
    else:
        cond_inputs = (iter_tens, a, acc)
    for tens in cond_inputs:
        self.ser.addInputTensor(tens)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])

    cond_type = (
        rng.choice([DType.INT8, DType.INT32, DType.FP32])
        if error_name == ErrorIf.CondGraphOutputNotMatchingBool
        else DType.BOOL
    )
    if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
        cond_shape = [3] if rng.choice([1, 2]) == 1 else [1, 2]
    else:
        cond_shape = []
    cond_tens = self.ser.addOutput(cond_shape, cond_type)

    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block (input: a, acc, iter, output: a, acc, iter)
    # Note that local intermediate tensors need to be declared here for the outputs
    self.ser.addBasicBlock(body_block)

    if error_name == ErrorIf.InputListBodyGraphInputMismatch:
        body_inputs = (wrong_iter, a, wrong_acc)
    else:
        body_inputs = (iter_tens, a, acc)
    for tens in body_inputs:
        self.ser.addInputTensor(tens)

    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])

    if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
        iter_like, acc_like = wrong_iter, wrong_acc
    else:
        iter_like, acc_like = iter_tens, acc
    iter_body_out = self.ser.addIntermediate(iter_like.shape, iter_like.dtype)
    acc_body_out = self.ser.addIntermediate(acc_like.shape, acc_like.dtype)

    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for tens in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(tens)

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        basicBlocks=self.ser.currRegion.basicBlocks,
    )
    if not is_valid:
        return None

    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, acc_out, error_name
    )

    return TosaTestGen.BuildInfo(acc_out, compliance)
Eric Kunzee5e26762020-10-13 16:11:07 -07002659
def build_fft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize an FFT2D operator taking two input tensors.

    Args:
        rng: random generator passed on to the output shaper and to the
            input/output list invalidation helper.
        op: operator description dictionary.
        inputs: list holding exactly two tensors.
        args_dict: must contain the ``"inverse"`` entry.
        validator_fcns: ERROR_IF validators handed to the error validator.
        error_name: ERROR_IF case being generated, or None.
        qinfo: accepted for interface compatibility, not used here.

    Returns:
        None when the ERROR_IF validation rejects the test, otherwise a
        ``TosaTestGen.BuildInfo`` holding the result tensors together with
        one compliance entry per result.
    """
    assert len(inputs) == 2
    first, second = inputs
    inverse = args_dict["inverse"]

    results = OutputShaper.fft2dOp(self.ser, rng, first, second, error_name)

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # Shapes/dtypes are captured before the name lists may be invalidated
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng,
        error_name,
        [first.name, second.name],
        [tens.name for tens in results],
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        inverse=inverse,
        input1=first,
        input2=second,
        input_shape=first.shape,
        input_dtype=first.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.FFTAttribute(inverse, local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, first.dtype, args_dict, tens, error_name)
        for tens in results
    ]

    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton57287132023-02-06 14:54:18 +00002724
def build_rfft2d(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize an RFFT2D operator taking a single input tensor.

    Args:
        rng: random generator passed on to the output shaper and to the
            input/output list invalidation helper.
        op: operator description dictionary.
        inputs: list holding exactly one tensor.
        args_dict: argument dictionary, forwarded to the compliance meta
            data helper.
        validator_fcns: ERROR_IF validators handed to the error validator.
        error_name: ERROR_IF case being generated, or None.
        qinfo: accepted for interface compatibility, not used here.

    Returns:
        None when the ERROR_IF validation rejects the test, otherwise a
        ``TosaTestGen.BuildInfo`` holding the result tensors together with
        one compliance entry per result.
    """
    assert len(inputs) == 1
    val = inputs[0]
    results = OutputShaper.rfft2dOp(self.ser, rng, val, error_name)

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # Shapes/dtypes are captured before the name lists may be invalidated
    output_shapes = [tens.shape for tens in results]
    output_dtypes = [tens.dtype for tens in results]

    input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, [val.name], [tens.name for tens in results]
    )

    is_valid = TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input_shape=val.shape,
        input_dtype=val.dtype,
        output_shape=output_shapes,
        output_dtype=output_dtypes,
        result_tensors=results,
        input_list=input_names,
        output_list=output_names,
        num_operands=num_operands,
    )
    if not is_valid:
        return None

    # TODO - Test local_bound, for now set local bound attribute to False
    local_bound = False

    attr = ts.TosaSerializerAttribute()
    attr.RFFTAttribute(local_bound)

    self.ser.addOperator(op["op"], input_names, output_names, attr)

    compliance = [
        self.tensorComplianceMetaData(op, val.dtype, args_dict, tens, error_name)
        for tens in results
    ]

    return TosaTestGen.BuildInfo(results, compliance)
Luke Hutton261b7b62023-01-10 14:50:31 +00002782
def build_shape_op(
    self,
    rng,
    op,
    inputs,
    args_dict,
    validator_fcns=None,
    error_name=None,
    qinfo=None,
):
    """Serialize a two-input shape operator.

    Args:
        rng: random generator passed on to the output shaper and to the
            input/output list invalidation helper.
        op: operator description dictionary.
        inputs: list holding exactly two tensors.
        args_dict: argument dictionary, forwarded to the compliance meta
            data helper.
        validator_fcns: ERROR_IF validators handed to the error validator.
        error_name: ERROR_IF case being generated, or None.
        qinfo: accepted for interface compatibility, not used here.

    Returns:
        None when the ERROR_IF validation rejects the test, otherwise a
        ``TosaTestGen.BuildInfo`` for the result tensor.
    """
    assert len(inputs) == 2
    a, b = inputs

    result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name)

    # Invalidate Input/Output list for error if checks.
    input_list = [a.name, b.name]
    output_list = [result_tensor.name]
    pCount, cCount = op["operands"]
    num_operands = pCount + cCount
    # The helper's first argument is the random generator (as in the other
    # build functions), not the test generator instance.
    input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
        rng, error_name, input_list, output_list
    )

    if not TosaErrorValidator.evValidateErrorIfs(
        self.ser,
        validator_fcns,
        error_name,
        op=op,
        input1=a,
        input2=b,
        input_shape=a.shape,
        input_dtype=a.dtype,
        output_shape=result_tensor.shape,
        output_dtype=result_tensor.dtype,
        result_tensors=[result_tensor],
        input_list=input_list,
        output_list=output_list,
        num_operands=num_operands,
    ):
        return None

    self.ser.addOperator(
        op["op"],
        input_list,
        output_list,
    )
    compliance = self.tensorComplianceMetaData(
        op, a.dtype, args_dict, result_tensor, error_name
    )

    return TosaTestGen.BuildInfo(result_tensor, compliance)
2835
def create_filter_lists(
    self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
    """Work out the shape, rank and dtype filters to generate tests with.

    Args:
        op: operator description; ``op["rank"]`` is the inclusive
            (min, max) rank pair and ``op["types"]`` the supported dtypes.
        shapeFilter: requested shapes; an empty/None value becomes [None].
        rankFilter: requested ranks, or None for the default behaviour.
        dtypeFilter: requested dtypes, or None to keep all of ``op["types"]``.
        testType: "positive" or "negative".
        validator: ERROR_IF validator; called with ``check=False`` for
            negative tests to obtain its ``"param_reqs"`` overrides.

    Returns:
        A dictionary with "shapeFilter", "rankFilter" and "dtypeFilter"
        entries, or None for a negative test without a validator (or for
        any other test type).
    """
    # Default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
    default_test_rank_range = range(1, 5)
    if not shapeFilter:
        shapeFilter = [None]

    # Ranks: combine what is requested with what the operator allows
    rmin, rmax = op["rank"]
    if rankFilter is not None:
        # Only keep the requested ranks that the operator supports
        cleanRankFilter = [rank for rank in rankFilter if rmin <= rank <= rmax]
    elif shapeFilter[0] is None:
        # Nothing requested - bounded by the default range or by the operator,
        # whichever is the smaller range of ranks.
        opRankRange = range(rmin, rmax + 1)
        if len(opRankRange) <= len(default_test_rank_range):
            cleanRankFilter = opRankRange
        else:
            cleanRankFilter = default_test_rank_range
    else:
        cleanRankFilter = range(rmin, rmax + 1)

    # Dtypes: operator dtypes filtered by the requested dtypes; a list
    # entry is matched on its first element
    dtypes = op["types"]
    if dtypeFilter is None:
        cleanDtypeFilter = dtypes
    else:
        cleanDtypeFilter = [
            dtype
            for dtype in dtypes
            if dtype in dtypeFilter
            or (isinstance(dtype, list) and dtype[0] in dtypeFilter)
        ]

    if testType == "positive":
        return {
            "shapeFilter": shapeFilter,
            "rankFilter": cleanRankFilter,
            "dtypeFilter": cleanDtypeFilter,
        }

    if testType != "negative" or validator is None:
        return None

    error_arguments = validator(check=False, op=op)["param_reqs"]

    # Parameters required by the validator override the cleaned filters
    required_rank = error_arguments["rank"]
    negRankFilter = cleanRankFilter if required_rank is None else required_rank

    required_dtype = error_arguments["dtype"]
    negDtypeFilter = cleanDtypeFilter if required_dtype is None else required_dtype

    required_shape = error_arguments["shape"]
    if required_shape is None:
        # Reduce number of shapes to keep test numbers small
        negShapeFilter = shapeFilter[:2]
    else:
        negShapeFilter = required_shape

    return {
        "shapeFilter": negShapeFilter,
        "rankFilter": negRankFilter,
        "dtypeFilter": negDtypeFilter,
    }
2916
def genOpTestList(
    self,
    opName,
    shapeFilter=[None],
    rankFilter=None,
    dtypeFilter=None,
    testType="positive",
):
    """Generate the list of tests for one operator.

    Args:
        opName: key of the operator in ``self.TOSA_OP_LIST``.
        shapeFilter: shapes to generate tests for (the shared default list
            is only read here, never modified).
        rankFilter: ranks to generate tests for, or None.
        dtypeFilter: dtypes to generate tests for, or None.
        testType: "positive" or "negative".

    Returns:
        A list of ``(opName, testStr, dtype, error_name, shapeList, args)``
        tuples; an empty list when no filters could be created.

    Raises:
        Exception: when the operator is unknown, or when (outside level8k
            mode) some ERROR_IF validators did not produce any test.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    if not self.args.stable_rng:
        # Initialize a new random number generator per op
        self.resetGlobalRNG()

    _, tgen_fcn, _, agen_fcn = op["build_fcn"]

    testList = []
    if testType == "negative" and "error_if_validators" in op:
        error_if_validators = op["error_if_validators"]
        num_error_types_created = 0
    else:
        # A single pass without a validator; nothing to count
        error_if_validators = [None]
        num_error_types_created = None

    for validator in error_if_validators:
        error_name = (
            None if validator is None else validator(check=False, op=op)["error_name"]
        )

        filterDict = self.create_filter_lists(
            op, shapeFilter, rankFilter, dtypeFilter, testType, validator
        )
        if filterDict is None:
            return []
        cleanRankFilter = filterDict["rankFilter"]
        cleanDtypeFilter = filterDict["dtypeFilter"]
        cleanShapeFilter = filterDict["shapeFilter"]
        logger.debug(
            f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
        )

        for r in cleanRankFilter:
            for t in cleanDtypeFilter:
                for shape in cleanShapeFilter:
                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue
                    self.setTargetShape(shape)
                    typeStr = self.typeStr(t)
                    shape_rng = (
                        TosaHashRandomGenerator(
                            self.random_seed,
                            [opName, r, typeStr],
                            self.random_dtype_range,
                        )
                        if self.args.stable_rng
                        else self.global_rng
                    )
                    shapeList = tgen_fcn(self, shape_rng, op, r, error_name)

                    shapeStr = self.shapeStr(shapeList[0])

                    # Argument list entries are (str, args) pairs: the string
                    # representation and the build function arguments
                    if not agen_fcn:
                        argList = [("", [])]
                    else:
                        arg_rng = (
                            TosaHashRandomGenerator(
                                self.random_seed,
                                [opName, shapeStr, typeStr],
                                self.random_dtype_range,
                            )
                            if self.args.stable_rng
                            else self.global_rng
                        )
                        argList = agen_fcn(
                            self, arg_rng, opName, shapeList, t, error_name
                        )

                    for argStr, args in argList:
                        # Create the test name string - for example: add_1x2x3_i32
                        if testType == "positive":
                            name_parts = [opName, shapeStr, typeStr]
                        else:
                            assert testType == "negative"
                            name_parts = [
                                opName,
                                "ERRORIF",
                                error_name,
                                shapeStr,
                                typeStr,
                            ]
                        if argStr:
                            name_parts.append(argStr)

                        testList.append(
                            (
                                opName,
                                "_".join(name_parts),
                                t,
                                error_name,
                                shapeList,
                                args,
                            )
                        )

        if error_name is None:
            continue

        # Check the last test is of the error we wanted
        if testList and testList[-1][3] == error_name:
            # Successfully created at least one ERROR_IF test
            num_error_types_created += 1
            continue

        if self.args.level8k:
            logger.info(f"Missing {error_name} tests due to level8k mode")
        else:
            logger.error(f"ERROR: Failed to create any {error_name} tests")
        logger.debug(
            "Last test created: {}".format(testList[-1] if testList else None)
        )

    if testType != "positive":
        if num_error_types_created is not None and not self.args.level8k:
            remaining_error_types = len(error_if_validators) - num_error_types_created
            if remaining_error_types:
                raise Exception(
                    f"Failed to create {remaining_error_types} error types for {opName}"
                )
        return testList

    if "invalid_test_validators" not in op:
        return testList

    # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
    invalid_test_validators = op["invalid_test_validators"]
    clean_testList = []
    for test in testList:
        # Every validator is consulted for every test (no early exit)
        verdicts = [
            validator_fcn(
                opName=test[0],
                input_dtype=test[2],
                shapeList=test[4],
                args=test[5],
            )
            for validator_fcn in invalid_test_validators
        ]
        if not any(verdicts):
            clean_testList.append(test)

    return clean_testList
3068
def serializeTest(
    self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
    """Build one test and serialize it when the build succeeds.

    Args:
        opName: key of the operator in ``self.TOSA_OP_LIST``.
        testStr: test name; also seeds the hash based random generator
            when ``self.args.stable_rng`` is set.
        dtype_or_dtypeList: a single dtype or a list with one dtype per
            operand.
        error_name: ERROR_IF case being generated, or None.
        shapeList: shapes of the operands.
        argsDict: argument dictionary; must contain ``"dg_type"``.

    Returns:
        True when the test was built and serialized, False when the build
        function returned a falsy result.

    Raises:
        Exception: when the operator is unknown.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError:
        raise Exception("Cannot find op with name {}".format(opName))

    logger.info(f"Creating {testStr}")

    # Create a serializer
    self.createSerializer(opName, testStr)

    build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
    error_if_validators = op.get("error_if_validators")

    pCount, cCount = op["operands"]
    num_operands = pCount + cCount

    # CONCAT style ops take their operand count from the shape list
    is_concat = op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE)

    if isinstance(dtype_or_dtypeList, list):
        dtypeList = dtype_or_dtypeList
    else:
        repeats = len(shapeList) if is_concat else num_operands
        dtypeList = [dtype_or_dtypeList] * repeats

    if not is_concat:
        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

    qgen = op.get("qgen")

    # Build the random tensor operands and the test

    # Set the random number generator
    build_rng = (
        TosaHashRandomGenerator(
            self.random_seed, [testStr], self.random_dtype_range
        )
        if self.args.stable_rng
        else self.global_rng
    )

    qinfo = (
        None
        if qgen is None
        else qgen(build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name)
    )

    # Extra meta data for the desc.json
    tensMeta = {}

    # Check we are using the new interface with an argsDict dictionary
    assert isinstance(
        argsDict, dict
    ), f"{opName} is not using new tvg/build_fcn interface"

    # New interface with args info in dictionary
    assert "dg_type" in argsDict
    tvgInfo = tvgen_fcn(
        self, build_rng, opName, dtypeList, shapeList, argsDict, error_name
    )
    if tvgInfo.dataGenDict:
        tensMeta["data_gen"] = tvgInfo.dataGenDict

    result = build_fcn(
        self,
        build_rng,
        op,
        tvgInfo.tensorList,
        argsDict,
        validator_fcns=error_if_validators,
        error_name=error_name,
        qinfo=qinfo,
    )

    if not result:
        # The test is not valid
        logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
        return False

    # The test is valid, serialize it
    if isinstance(result, TosaTestGen.BuildInfo):
        # Add the compliance meta data (if any)
        compliance = result.getComplianceInfo()
        if compliance:
            tensMeta["compliance"] = compliance
    self.serialize("test", tensMeta)
    return True
Matthew Haddon1c00b712021-10-01 15:51:03 +01003173
def createDynamicOpLists(self):
    """Replace every template op with one concrete op per kernel size.

    For each entry of ``self.TOSA_OP_LIST`` with a truthy ``"template"``
    field, an op named ``<realName>_<kernel dims joined by x>`` is added
    for each chosen kernel and the template entry itself is deleted.
    Kernels come from the level8k presets, from ``self.args.conv_kernels``
    (those of matching rank) or, failing that, from the template's own
    ``"filter"`` list.
    """
    suffix = "_TEMPLATE"

    # Find all the ops marked as templates
    templateKeys = [
        name for name, info in self.TOSA_OP_LIST.items() if info.get("template")
    ]

    bigK = self.TOSA_8K_LEVEL_MAX_KERNEL

    # Add dynamic ops based on kernel sizes
    for opName in templateKeys:
        assert opName.endswith(suffix), "Found incorrect template"
        realName = opName[: len(opName) - len(suffix)]
        template = self.TOSA_OP_LIST[opName]
        k_rank = 3 if realName == "conv3d" else 2

        # Choose kernels to build tests for from the template or args
        if self.args.level8k:
            kernels = (
                [[1, bigK, 1], [2, 2, bigK]]
                if k_rank == 3
                else [[1, bigK], [bigK, 2]]
            )
        else:
            kernels = []
            if len(self.args.conv_kernels) > 0:
                kernels = [k for k in self.args.conv_kernels if len(k) == k_rank]
                if not kernels:
                    logger.debug(
                        f"{realName} op using defaults as no rank {k_rank} kernels found in {self.args.conv_kernels}"
                    )
            if not kernels:
                # Fallback to use the defined template kernels
                kernels = template["filter"]

        # Dynamically create ops for listed kernel sizes
        for k in kernels:
            kernelStr = "x".join(str(d) for d in k)
            testName = f"{realName}_{kernelStr}"
            self.TOSA_OP_LIST[testName] = {
                **template,
                "filter": k,
                "template": False,
                "real_name": realName,
            }

        # Delete the template after having created the dynamic ops
        del self.TOSA_OP_LIST[opName]
Eric Kunzee5e26762020-10-13 16:11:07 -07003223
def initOpListDefaults(self):
    """Fill in default fields for ops if they aren't already specified.
    Look for missing required fields (datastructure linting).

    Every entry of ``self.TOSA_OP_LIST`` must provide a two element
    "operands" tuple, a four element "build_fcn" tuple, a "types" field
    and an "op" field; an entry without "rank" receives
    ``self.DEFAULT_RANK_RANGE``.

    Raises:
        Exception: when one of the required fields is missing or cannot
            be unpacked.
    """
    bad_tuple_errors = (KeyError, ValueError, TypeError)
    # Fields that only need to be present, with their complaint
    presence_checks = (
        ("types", "Op {} is missing a valid type list in TOSA_OP_LIST"),
        ("op", "Op {} is missing the Op field in TOSA_OP_LIST"),
    )

    for op, fields in self.TOSA_OP_LIST.items():
        # Required fields - unpacking also checks the tuple sizes
        try:
            _, _ = fields["operands"]
        except bad_tuple_errors:
            raise Exception(
                "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
            )

        try:
            _, _, _, _ = fields["build_fcn"]
        except bad_tuple_errors:
            raise Exception(
                "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(op)
            )

        for field, complaint in presence_checks:
            try:
                fields[field]
            except KeyError:
                raise Exception(complaint.format(op))

        # Put in default rank range, if missing
        try:
            fields["rank"]
        except KeyError:
            fields["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07003265
3266 # Tensor operator list
3267 # 'op': op name
3268 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08003269 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
3270 # if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of (build function, TensorGen function, TensorValuesGen function, ArgGen function)
3272 # 'types': array of datatypes to be tested
# Data type groupings referenced by the "types" field of operator entries.
# Element order is kept exactly as listed, in case consumers depend on it.

# Floating-point types only
TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]

# Integer types only - excludes INT4
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]

# Integer types followed by floating-point types - excludes INT4
TYPE_INT_FP = TYPE_INT + [DType.FP16, DType.BF16, DType.FP32]

TYPE_BOOL = [DType.BOOL]

# Floating-point types and INT32
TYPE_FI32 = TYPE_FP + [DType.INT32]

# Floating-point, integer and boolean types
TYPE_FIB = [DType.FP16, DType.BF16, DType.FP32] + TYPE_INT + TYPE_BOOL

TYPE_FI16 = [DType.FP32, DType.INT16]

# 8 and 16 bit integers plus the floating-point types
TYPE_NARROW_INT_FP = [
    DType.INT8,
    DType.INT16,
    DType.FP16,
    DType.BF16,
    DType.FP32,
]

# Each entry is [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
    # Integer inputs
    [DType.INT8, DType.INT4, DType.INT32],
    [DType.INT8, DType.INT8, DType.INT32],
    [DType.INT16, DType.INT8, DType.INT48],
    # 16 and 32 bit floating-point inputs
    [DType.FP16, DType.FP16, DType.FP16],
    [DType.FP16, DType.FP16, DType.FP32],
    [DType.BF16, DType.BF16, DType.FP32],
    [DType.FP32, DType.FP32, DType.FP32],
    # 8 bit floating-point inputs
    [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
    [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]

# Rank range applied to any operator entry that has no "rank" field
DEFAULT_RANK_RANGE = (1, gtu.MAX_TENSOR_RANK)

# Kernel shapes referenced by the "filter" field of templated operators
KERNELS_2D = [
    [1, 1],
    [2, 2],
    [3, 3],
    [5, 5],
    [3, 1],
    [1, 3],
]
KERNELS_3D = [
    [1, 1, 1],
    [2, 1, 1],
    [1, 2, 1],
    [1, 1, 2],
]
3328
Eric Kunzee5e26762020-10-13 16:11:07 -07003329 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08003330 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08003331 "argmax": {
3332 "op": Op.ARGMAX,
3333 "operands": (1, 0),
Jerry Ge0bd4ec82023-05-01 18:36:43 +00003334 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003335 "build_fcn": (
3336 build_argmax,
3337 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003338 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003339 TosaArgGen.agAxis,
3340 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003341 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003342 "error_if_validators": (
3343 TosaErrorValidator.evAxisSmallerZero,
3344 TosaErrorValidator.evAxisLargerRank,
3345 TosaErrorValidator.evArgmaxOutputRankMismatch,
3346 TosaErrorValidator.evArgmaxOutputShapeMismatch,
3347 TosaErrorValidator.evWrongRank,
3348 TosaErrorValidator.evWrongInputType,
3349 TosaErrorValidator.evWrongOutputType,
3350 TosaErrorValidator.evWrongInputList,
3351 TosaErrorValidator.evWrongOutputList,
3352 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00003353 "data_gen": {
3354 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3355 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003356 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003357 "avg_pool2d": {
3358 "op": Op.AVG_POOL2D,
3359 "operands": (1, 0),
3360 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003361 "build_fcn": (
3362 build_pool2d,
3363 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003364 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003365 TosaArgGen.agPooling,
3366 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003367 "qgen": TosaQuantGen.qgUnary,
Won Jeon2c34b462024-02-06 18:37:00 +00003368 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003369 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003370 "error_if_validators": (
3371 TosaErrorValidator.evKernelSmallerOne,
3372 TosaErrorValidator.evStrideSmallerOne,
3373 TosaErrorValidator.evPadSmallerZero,
3374 TosaErrorValidator.evWrongRank,
3375 TosaErrorValidator.evWrongInputType,
3376 TosaErrorValidator.evWrongOutputType,
3377 TosaErrorValidator.evWrongInputList,
3378 TosaErrorValidator.evWrongOutputList,
3379 TosaErrorValidator.evInputZeroPointNotZero,
3380 TosaErrorValidator.evOutputZeroPointNotZero,
3381 TosaErrorValidator.evPadLargerEqualKernel,
3382 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003383 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00003384 TosaErrorValidator.evWrongAccumulatorType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003385 ),
Jeremy Johnson0601f802023-11-08 16:28:09 +00003386 "data_gen": {
3387 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3388 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003389 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003390 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003391 "conv2d_TEMPLATE": {
3392 "op": Op.CONV2D,
3393 "operands": (1, 2),
3394 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003395 "build_fcn": (
3396 build_conv2d,
3397 TosaTensorGen.tgConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003398 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003399 TosaArgGen.agConv,
3400 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003401 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003402 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003403 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3404 "error_if_validators": (
3405 TosaErrorValidator.evWrongInputType,
3406 TosaErrorValidator.evWrongOutputType,
3407 TosaErrorValidator.evWrongInputList,
3408 TosaErrorValidator.evWrongOutputList,
3409 TosaErrorValidator.evInputZeroPointNotZero,
3410 TosaErrorValidator.evWeightZeroPointNotZero,
3411 TosaErrorValidator.evPadSmallerZero,
3412 TosaErrorValidator.evStrideSmallerOne,
3413 TosaErrorValidator.evDilationSmallerOne,
3414 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003415 TosaErrorValidator.evConvOutputShapeMismatch,
3416 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003417 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003418 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003419 "data_gen": {
3420 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3421 },
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003422 "broadcastable_bias": True,
3423 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003424 "template": True,
3425 },
Kevin Cheng1533b852021-09-01 12:51:58 -07003426 # Templated operator. Filled in by createDynamicOpLists
3427 "conv3d_TEMPLATE": {
3428 "op": Op.CONV3D,
3429 "operands": (1, 2),
3430 "rank": (5, 5),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003431 "build_fcn": (
3432 build_conv3d,
3433 TosaTensorGen.tgConv3D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003434 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003435 TosaArgGen.agConv,
3436 ),
Kevin Cheng1533b852021-09-01 12:51:58 -07003437 "qgen": TosaQuantGen.qgConv,
3438 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003439 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3440 "error_if_validators": (
3441 TosaErrorValidator.evWrongInputType,
3442 TosaErrorValidator.evWrongOutputType,
3443 TosaErrorValidator.evWrongInputList,
3444 TosaErrorValidator.evWrongOutputList,
3445 TosaErrorValidator.evInputZeroPointNotZero,
3446 TosaErrorValidator.evWeightZeroPointNotZero,
3447 TosaErrorValidator.evPadSmallerZero,
3448 TosaErrorValidator.evStrideSmallerOne,
3449 TosaErrorValidator.evDilationSmallerOne,
3450 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003451 TosaErrorValidator.evConvOutputShapeMismatch,
3452 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003453 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003454 ),
evacha0147ab1762024-01-29 13:23:23 +00003455 "data_gen": {
3456 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3457 },
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003458 "filter": KERNELS_3D,
Kevin Cheng1533b852021-09-01 12:51:58 -07003459 "template": True,
3460 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003461 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003462 "depthwise_conv2d_TEMPLATE": {
3463 "op": Op.DEPTHWISE_CONV2D,
3464 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003465 "rank": (4, 4),
3466 "build_fcn": (
3467 build_depthwise_conv2d,
3468 TosaTensorGen.tgDepthwiseConv2D,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01003469 TosaTensorValuesGen.tvgLazyGenDefault,
Les Bell7aa69f42021-09-20 10:44:07 +01003470 TosaArgGen.agConv,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003471 ),
3472 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003473 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003474 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
3475 "error_if_validators": (
3476 TosaErrorValidator.evWrongInputType,
3477 TosaErrorValidator.evWrongOutputType,
3478 TosaErrorValidator.evWrongInputList,
3479 TosaErrorValidator.evWrongOutputList,
3480 TosaErrorValidator.evInputZeroPointNotZero,
3481 TosaErrorValidator.evWeightZeroPointNotZero,
3482 TosaErrorValidator.evPadSmallerZero,
3483 TosaErrorValidator.evStrideSmallerOne,
3484 TosaErrorValidator.evDilationSmallerOne,
3485 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003486 TosaErrorValidator.evConvOutputShapeMismatch,
3487 TosaErrorValidator.evConvOutputShapeNonInteger,
Tai Lyf36f2562024-03-14 16:21:29 +00003488 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003489 ),
Jeremy Johnson4f931302024-01-04 17:05:24 +00003490 "data_gen": {
3491 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3492 },
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003493 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003494 "template": True,
3495 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003496 "fully_connected": {
3497 "op": Op.FULLY_CONNECTED,
3498 "operands": (1, 2),
3499 "rank": (2, 2),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003500 "build_fcn": (
3501 build_fully_connected,
3502 TosaTensorGen.tgFullyConnected,
Jeremy Johnson30476252023-11-20 16:15:30 +00003503 TosaTensorValuesGen.tvgFullyConnected,
James Ward8b390432022-08-12 20:48:56 +01003504 TosaArgGen.agFullyConnected,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003505 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003506 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003507 "types": TYPE_CONV,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003508 "error_if_validators": (
3509 TosaErrorValidator.evInputZeroPointNotZero,
3510 TosaErrorValidator.evWeightZeroPointNotZero,
3511 TosaErrorValidator.evWrongRank,
3512 TosaErrorValidator.evWrongInputType,
3513 TosaErrorValidator.evWrongOutputType,
3514 TosaErrorValidator.evWrongInputList,
3515 TosaErrorValidator.evWrongOutputList,
3516 ),
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00003517 "data_gen": {
3518 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3519 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003520 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003521 "matmul": {
3522 "op": Op.MATMUL,
3523 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07003524 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003525 "build_fcn": (
3526 build_matmul,
3527 TosaTensorGen.tgMatmul,
Jeremy Johnson1271c442023-09-05 11:39:26 +01003528 TosaTensorValuesGen.tvgLazyGenDefault,
James Ward8b390432022-08-12 20:48:56 +01003529 TosaArgGen.agMatMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003530 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003531 "qgen": TosaQuantGen.qgMatmul,
Won Jeon2c34b462024-02-06 18:37:00 +00003532 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003533 "error_if_validators": (
3534 TosaErrorValidator.evInputZeroPointNotZero,
3535 TosaErrorValidator.evWrongRank,
3536 TosaErrorValidator.evWrongInputType,
3537 TosaErrorValidator.evWrongOutputType,
3538 TosaErrorValidator.evWrongInputList,
3539 TosaErrorValidator.evWrongOutputList,
3540 ),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003541 "data_gen": {
3542 "fp": (gtu.DataGenType.DOT_PRODUCT,),
Jeremy Johnson1271c442023-09-05 11:39:26 +01003543 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003544 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003545 "max_pool2d": {
3546 "op": Op.MAX_POOL2D,
3547 "operands": (1, 0),
3548 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003549 "build_fcn": (
Jeremy Johnson0601f802023-11-08 16:28:09 +00003550 build_pool2d,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551 TosaTensorGen.tgNHWC,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003552 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003553 TosaArgGen.agPooling,
3554 ),
Won Jeon2c34b462024-02-06 18:37:00 +00003555 "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Les Bell0e027d42021-11-09 14:42:14 +00003556 "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003557 "error_if_validators": (
3558 TosaErrorValidator.evKernelSmallerOne,
3559 TosaErrorValidator.evStrideSmallerOne,
3560 TosaErrorValidator.evPadSmallerZero,
3561 TosaErrorValidator.evWrongRank,
3562 TosaErrorValidator.evWrongInputType,
3563 TosaErrorValidator.evWrongOutputType,
3564 TosaErrorValidator.evWrongInputList,
3565 TosaErrorValidator.evWrongOutputList,
3566 TosaErrorValidator.evPadLargerEqualKernel,
3567 TosaErrorValidator.evPoolingOutputShapeMismatch,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003568 TosaErrorValidator.evPoolingOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003569 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01003570 "data_gen": {
3571 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3572 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003573 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003574 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08003575 "transpose_conv2d_TEMPLATE": {
3576 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07003577 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003578 "rank": (4, 4),
3579 "build_fcn": (
3580 build_transpose_conv2d,
3581 TosaTensorGen.tgTransposeConv2D,
Jeremy Johnson95a67102024-01-10 14:16:39 +00003582 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003583 TosaArgGen.agTransposeConv2D,
3584 ),
3585 "qgen": TosaQuantGen.qgConv,
Kevin Cheng1533b852021-09-01 12:51:58 -07003586 "types": TYPE_CONV,
Les Bell0e027d42021-11-09 14:42:14 +00003587 "invalid_test_validators": (
3588 TosaInvalidValidator.ivHeightWidthInvalid,
3589 TosaInvalidValidator.ivNonPositiveOutputShape,
3590 ),
3591 "error_if_validators": (
3592 TosaErrorValidator.evWrongInputType,
3593 TosaErrorValidator.evWrongOutputType,
3594 TosaErrorValidator.evWrongInputList,
3595 TosaErrorValidator.evWrongOutputList,
3596 TosaErrorValidator.evInputZeroPointNotZero,
3597 TosaErrorValidator.evWeightZeroPointNotZero,
Eric Kunzec1a97832022-07-01 16:56:09 -07003598 TosaErrorValidator.evPadLargerEqualKernel,
Les Bell0e027d42021-11-09 14:42:14 +00003599 TosaErrorValidator.evStrideSmallerOne,
Les Bell0e027d42021-11-09 14:42:14 +00003600 TosaErrorValidator.evWrongRank,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01003601 TosaErrorValidator.evConvOutputShapeMismatch,
Tai Lyf36f2562024-03-14 16:21:29 +00003602 TosaErrorValidator.evWrongAccumulatorType,
Les Bell0e027d42021-11-09 14:42:14 +00003603 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00003604 "data_gen": {
3605 "fp": (gtu.DataGenType.DOT_PRODUCT,),
3606 },
Jeremy Johnson5e36bde2024-03-14 16:56:10 +00003607 "filter": KERNELS_2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08003608 "template": True,
3609 },
Eric Kunzee5e26762020-10-13 16:11:07 -07003610 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08003611 "clamp": {
3612 "op": Op.CLAMP,
3613 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003614 "build_fcn": (
3615 build_clamp,
3616 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003617 TosaTensorValuesGen.tvgLazyGenDefault,
3618 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003619 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003620 "types": TYPE_NARROW_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003621 "error_if_validators": (
3622 TosaErrorValidator.evMaxSmallerMin,
3623 TosaErrorValidator.evWrongInputType,
3624 TosaErrorValidator.evWrongOutputType,
3625 TosaErrorValidator.evWrongInputList,
3626 TosaErrorValidator.evWrongOutputList,
3627 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003628 "data_gen": {
3629 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3630 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003631 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003632 "sigmoid": {
3633 "op": Op.SIGMOID,
3634 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003635 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003636 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003637 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003638 TosaTensorValuesGen.tvgLazyGenDefault,
3639 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003640 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003641 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003642 "error_if_validators": (
3643 TosaErrorValidator.evWrongInputType,
3644 TosaErrorValidator.evWrongOutputType,
3645 TosaErrorValidator.evWrongInputList,
3646 TosaErrorValidator.evWrongOutputList,
3647 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003648 "data_gen": {
3649 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3650 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003651 },
3652 "tanh": {
3653 "op": Op.TANH,
3654 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003655 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003656 build_activation,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003657 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003658 TosaTensorValuesGen.tvgLazyGenDefault,
3659 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003660 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08003661 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003662 "error_if_validators": (
3663 TosaErrorValidator.evWrongInputType,
3664 TosaErrorValidator.evWrongOutputType,
3665 TosaErrorValidator.evWrongInputList,
3666 TosaErrorValidator.evWrongOutputList,
3667 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003668 "data_gen": {
3669 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3670 },
Jeremy Johnsond80ea5e2024-01-03 10:54:12 +00003671 "compliance": {
3672 "abs_error_lower_bound": 0.5,
3673 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08003674 },
Won Jeon78155c62023-06-10 00:20:04 +00003675 "erf": {
3676 "op": Op.ERF,
3677 "operands": (1, 0),
3678 "build_fcn": (
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003679 build_activation,
Won Jeon78155c62023-06-10 00:20:04 +00003680 TosaTensorGen.tgBasic,
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003681 TosaTensorValuesGen.tvgLazyGenDefault,
3682 TosaArgGen.agNone,
Won Jeon78155c62023-06-10 00:20:04 +00003683 ),
3684 "types": TYPE_FP,
3685 "error_if_validators": (
3686 TosaErrorValidator.evWrongInputType,
3687 TosaErrorValidator.evWrongOutputType,
3688 TosaErrorValidator.evWrongInputList,
3689 TosaErrorValidator.evWrongOutputList,
3690 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00003691 "data_gen": {
3692 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3693 },
3694 "compliance": {"ulp": 5},
Won Jeon78155c62023-06-10 00:20:04 +00003695 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003696 # Elementwise Binary Operators
3697 "add": {
3698 "op": Op.ADD,
3699 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003700 "build_fcn": (
3701 build_binary_broadcast,
3702 TosaTensorGen.tgBroadcastFuzz,
3703 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003704 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003705 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003706 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003707 "error_if_validators": (
3708 TosaErrorValidator.evRankMismatch,
3709 TosaErrorValidator.evWrongInputType,
3710 TosaErrorValidator.evWrongOutputType,
3711 TosaErrorValidator.evWrongInputList,
3712 TosaErrorValidator.evWrongOutputList,
3713 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003714 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003715 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003716 "data_gen": {
3717 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3718 },
3719 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003720 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003721 "arithmetic_right_shift": {
3722 "op": Op.ARITHMETIC_RIGHT_SHIFT,
3723 "operands": (2, 0),
3724 "build_fcn": (
3725 build_arithmetic_right_shift,
3726 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003727 TosaTensorValuesGen.tvgArithmeticRightShift,
Jared Smolens573ecd42021-03-04 15:24:10 -08003728 TosaArgGen.agArithmeticRightShift,
3729 ),
3730 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003731 "error_if_validators": (
3732 TosaErrorValidator.evRankMismatch,
3733 TosaErrorValidator.evWrongInputType,
3734 TosaErrorValidator.evWrongOutputType,
3735 TosaErrorValidator.evWrongInputList,
3736 TosaErrorValidator.evWrongOutputList,
3737 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003738 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003739 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003740 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003741 "bitwise_and": {
3742 "op": Op.BITWISE_AND,
3743 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003744 "build_fcn": (
3745 build_binary_broadcast,
3746 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003747 TosaTensorValuesGen.tvgLazyGenDefault,
3748 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003749 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003750 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003751 "error_if_validators": (
3752 TosaErrorValidator.evRankMismatch,
3753 TosaErrorValidator.evWrongInputType,
3754 TosaErrorValidator.evWrongOutputType,
3755 TosaErrorValidator.evWrongInputList,
3756 TosaErrorValidator.evWrongOutputList,
3757 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003758 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003759 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003760 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003761 "bitwise_or": {
3762 "op": Op.BITWISE_OR,
3763 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003764 "build_fcn": (
3765 build_binary_broadcast,
3766 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003767 TosaTensorValuesGen.tvgLazyGenDefault,
3768 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003769 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003770 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003771 "error_if_validators": (
3772 TosaErrorValidator.evRankMismatch,
3773 TosaErrorValidator.evWrongInputType,
3774 TosaErrorValidator.evWrongOutputType,
3775 TosaErrorValidator.evWrongInputList,
3776 TosaErrorValidator.evWrongOutputList,
3777 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003778 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003779 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003780 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003781 "bitwise_xor": {
3782 "op": Op.BITWISE_XOR,
3783 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003784 "build_fcn": (
3785 build_binary_broadcast,
3786 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003787 TosaTensorValuesGen.tvgLazyGenDefault,
3788 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003789 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003790 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003791 "error_if_validators": (
3792 TosaErrorValidator.evRankMismatch,
3793 TosaErrorValidator.evWrongInputType,
3794 TosaErrorValidator.evWrongOutputType,
3795 TosaErrorValidator.evWrongInputList,
3796 TosaErrorValidator.evWrongOutputList,
3797 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003798 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003799 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003800 },
Matthew Haddon459443c2021-08-23 16:43:13 +01003801 "intdiv": {
3802 "op": Op.INTDIV,
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003803 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003804 "build_fcn": (
3805 build_binary_broadcast,
3806 TosaTensorGen.tgBroadcastFuzz,
3807 TosaTensorValuesGen.tvgIntDiv,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003808 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003809 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003810 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003811 "error_if_validators": (
3812 TosaErrorValidator.evRankMismatch,
3813 TosaErrorValidator.evWrongInputType,
3814 TosaErrorValidator.evWrongOutputType,
3815 TosaErrorValidator.evWrongInputList,
3816 TosaErrorValidator.evWrongOutputList,
3817 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003818 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003819 ),
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07003820 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003821 "logical_and": {
3822 "op": Op.LOGICAL_AND,
3823 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003824 "build_fcn": (
3825 build_binary_broadcast,
3826 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003827 TosaTensorValuesGen.tvgLazyGenDefault,
3828 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003829 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003830 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003831 "error_if_validators": (
3832 TosaErrorValidator.evRankMismatch,
3833 TosaErrorValidator.evWrongInputType,
3834 TosaErrorValidator.evWrongOutputType,
3835 TosaErrorValidator.evWrongInputList,
3836 TosaErrorValidator.evWrongOutputList,
3837 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003838 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003839 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003840 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003841 "logical_left_shift": {
3842 "op": Op.LOGICAL_LEFT_SHIFT,
3843 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003844 "build_fcn": (
3845 build_binary_broadcast,
3846 TosaTensorGen.tgBroadcastFuzz,
3847 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003848 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003849 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003850 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003851 "error_if_validators": (
3852 TosaErrorValidator.evRankMismatch,
3853 TosaErrorValidator.evWrongInputType,
3854 TosaErrorValidator.evWrongOutputType,
3855 TosaErrorValidator.evWrongInputList,
3856 TosaErrorValidator.evWrongOutputList,
3857 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003858 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003859 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003860 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003861 "logical_right_shift": {
3862 "op": Op.LOGICAL_RIGHT_SHIFT,
3863 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003864 "build_fcn": (
3865 build_binary_broadcast,
3866 TosaTensorGen.tgBroadcastFuzz,
3867 TosaTensorValuesGen.tvgLogicalShift,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003868 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003869 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003870 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003871 "error_if_validators": (
3872 TosaErrorValidator.evRankMismatch,
3873 TosaErrorValidator.evWrongInputType,
3874 TosaErrorValidator.evWrongOutputType,
3875 TosaErrorValidator.evWrongInputList,
3876 TosaErrorValidator.evWrongOutputList,
3877 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003878 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003879 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003880 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003881 "logical_or": {
3882 "op": Op.LOGICAL_OR,
3883 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003884 "build_fcn": (
3885 build_binary_broadcast,
3886 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003887 TosaTensorValuesGen.tvgLazyGenDefault,
3888 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003889 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003890 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003891 "error_if_validators": (
3892 TosaErrorValidator.evRankMismatch,
3893 TosaErrorValidator.evWrongInputType,
3894 TosaErrorValidator.evWrongOutputType,
3895 TosaErrorValidator.evWrongInputList,
3896 TosaErrorValidator.evWrongOutputList,
3897 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003898 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003899 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003900 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003901 "logical_xor": {
3902 "op": Op.LOGICAL_XOR,
3903 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003904 "build_fcn": (
3905 build_binary_broadcast,
3906 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003907 TosaTensorValuesGen.tvgLazyGenDefault,
3908 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003909 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003910 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003911 "error_if_validators": (
3912 TosaErrorValidator.evRankMismatch,
3913 TosaErrorValidator.evWrongInputType,
3914 TosaErrorValidator.evWrongOutputType,
3915 TosaErrorValidator.evWrongInputList,
3916 TosaErrorValidator.evWrongOutputList,
3917 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003918 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003919 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003920 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003921 "maximum": {
3922 "op": Op.MAXIMUM,
3923 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003924 "build_fcn": (
3925 build_binary_broadcast,
3926 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003927 TosaTensorValuesGen.tvgLazyGenDefault,
3928 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003929 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003930 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003931 "error_if_validators": (
3932 TosaErrorValidator.evRankMismatch,
3933 TosaErrorValidator.evWrongInputType,
3934 TosaErrorValidator.evWrongOutputType,
3935 TosaErrorValidator.evWrongInputList,
3936 TosaErrorValidator.evWrongOutputList,
3937 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003938 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003939 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003940 "data_gen": {
3941 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3942 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003943 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003944 "minimum": {
3945 "op": Op.MINIMUM,
3946 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003947 "build_fcn": (
3948 build_binary_broadcast,
3949 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003950 TosaTensorValuesGen.tvgLazyGenDefault,
3951 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003952 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003953 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003954 "error_if_validators": (
3955 TosaErrorValidator.evRankMismatch,
3956 TosaErrorValidator.evWrongInputType,
3957 TosaErrorValidator.evWrongOutputType,
3958 TosaErrorValidator.evWrongInputList,
3959 TosaErrorValidator.evWrongOutputList,
3960 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003961 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003962 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00003963 "data_gen": {
3964 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3965 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003966 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003967 "mul": {
3968 "op": Op.MUL,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003969 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003970 "build_fcn": (
3971 build_mul,
Jeremy Johnson0a042992024-02-28 13:20:05 +00003972 TosaTensorGen.tgMul,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003973 TosaTensorValuesGen.tvgMul,
3974 TosaArgGen.agMul,
3975 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08003976 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003977 "error_if_validators": (
3978 TosaErrorValidator.evWrongInputType,
3979 TosaErrorValidator.evWrongOutputType,
3980 TosaErrorValidator.evWrongInputList,
3981 TosaErrorValidator.evWrongOutputList,
3982 TosaErrorValidator.evRankMismatch,
3983 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00003984 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00003985 ),
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01003986 "data_gen": {
3987 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
3988 },
3989 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08003990 },
Jared Smolens573ecd42021-03-04 15:24:10 -08003991 "pow": {
3992 "op": Op.POW,
3993 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003994 "build_fcn": (
3995 build_binary_broadcast,
3996 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnson30476252023-11-20 16:15:30 +00003997 TosaTensorValuesGen.tvgPow,
3998 TosaArgGen.agPow,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003999 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004000 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004001 "error_if_validators": (
4002 TosaErrorValidator.evRankMismatch,
4003 TosaErrorValidator.evWrongInputType,
4004 TosaErrorValidator.evWrongOutputType,
4005 TosaErrorValidator.evWrongInputList,
4006 TosaErrorValidator.evWrongOutputList,
4007 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004008 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004009 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004010 "data_gen": {
4011 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4012 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004013 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004014 "sub": {
4015 "op": Op.SUB,
4016 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004017 "build_fcn": (
4018 build_binary_broadcast,
4019 TosaTensorGen.tgBroadcastFuzz,
4020 TosaTensorValuesGen.tvgAddSub,
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004021 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004022 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004023 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004024 "error_if_validators": (
4025 TosaErrorValidator.evRankMismatch,
4026 TosaErrorValidator.evWrongInputType,
4027 TosaErrorValidator.evWrongOutputType,
4028 TosaErrorValidator.evWrongInputList,
4029 TosaErrorValidator.evWrongOutputList,
4030 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004031 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004032 ),
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00004033 "data_gen": {
4034 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4035 },
4036 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004037 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004038 "table": {
4039 "op": Op.TABLE,
4040 # Use the automatic generation functions to create the input array
4041 # but create the table tensor in the build function, as it may be
4042 # a different type from the input
4043 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004044 "build_fcn": (
4045 build_table,
4046 TosaTensorGen.tgBasic,
Jeremy Johnson587cc842024-02-08 11:45:44 +00004047 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004048 TosaArgGen.agTable,
4049 ),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01004050 "types": [DType.INT8, DType.INT16],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004051 "error_if_validators": (
4052 TosaErrorValidator.evWrongInputType,
4053 TosaErrorValidator.evWrongOutputType,
4054 TosaErrorValidator.evWrongInputList,
4055 TosaErrorValidator.evWrongOutputList,
4056 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004057 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004058 # Elementwise Unary operators
4059 "abs": {
4060 "op": Op.ABS,
4061 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004062 "build_fcn": (
4063 build_unary,
4064 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004065 TosaTensorValuesGen.tvgLazyGenDefault,
4066 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004067 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004068 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004069 "error_if_validators": (
4070 TosaErrorValidator.evWrongInputType,
4071 TosaErrorValidator.evWrongOutputType,
4072 TosaErrorValidator.evWrongInputList,
4073 TosaErrorValidator.evWrongOutputList,
4074 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004075 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004076 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004077 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004078 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004079 "bitwise_not": {
4080 "op": Op.BITWISE_NOT,
4081 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004082 "build_fcn": (
4083 build_unary,
4084 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004085 TosaTensorValuesGen.tvgLazyGenDefault,
4086 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004087 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004088 "types": TYPE_INT,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004089 "error_if_validators": (
4090 TosaErrorValidator.evWrongInputType,
4091 TosaErrorValidator.evWrongOutputType,
4092 TosaErrorValidator.evWrongInputList,
4093 TosaErrorValidator.evWrongOutputList,
4094 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004095 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004096 "ceil": {
4097 "op": Op.CEIL,
4098 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004099 "build_fcn": (
4100 build_unary,
4101 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004102 TosaTensorValuesGen.tvgLazyGenDefault,
4103 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004104 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004105 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004106 "error_if_validators": (
4107 TosaErrorValidator.evWrongInputType,
4108 TosaErrorValidator.evWrongOutputType,
4109 TosaErrorValidator.evWrongInputList,
4110 TosaErrorValidator.evWrongOutputList,
4111 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004112 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004113 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004114 },
4115 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004117 "clz": {
4118 "op": Op.CLZ,
4119 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004120 "build_fcn": (
4121 build_unary,
4122 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004123 TosaTensorValuesGen.tvgLazyGenDefault,
4124 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004125 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004126 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004127 "error_if_validators": (
4128 TosaErrorValidator.evWrongInputType,
4129 TosaErrorValidator.evWrongOutputType,
4130 TosaErrorValidator.evWrongInputList,
4131 TosaErrorValidator.evWrongOutputList,
4132 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004133 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004134 "cos": {
4135 "op": Op.COS,
4136 "operands": (1, 0),
4137 "build_fcn": (
4138 build_unary,
4139 TosaTensorGen.tgBasic,
4140 TosaTensorValuesGen.tvgLazyGenDefault,
4141 TosaArgGen.agNone,
4142 ),
4143 "types": TYPE_FP,
4144 "error_if_validators": (
4145 TosaErrorValidator.evWrongInputType,
4146 TosaErrorValidator.evWrongOutputType,
4147 TosaErrorValidator.evWrongInputList,
4148 TosaErrorValidator.evWrongOutputList,
4149 ),
4150 "data_gen": {
4151 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4152 },
4153 "compliance": {"abs_error_normal_divisor": 2},
4154 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004155 "exp": {
4156 "op": Op.EXP,
4157 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004158 "build_fcn": (
4159 build_unary,
4160 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004161 TosaTensorValuesGen.tvgExp,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004162 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004163 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004164 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004165 "error_if_validators": (
4166 TosaErrorValidator.evWrongInputType,
4167 TosaErrorValidator.evWrongOutputType,
4168 TosaErrorValidator.evWrongInputList,
4169 TosaErrorValidator.evWrongOutputList,
4170 ),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004171 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004172 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson9a758382023-11-07 16:27:35 +00004173 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004174 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004175 "floor": {
4176 "op": Op.FLOOR,
4177 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004178 "build_fcn": (
4179 build_unary,
4180 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004181 TosaTensorValuesGen.tvgLazyGenDefault,
4182 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004183 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004184 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004185 "error_if_validators": (
4186 TosaErrorValidator.evWrongInputType,
4187 TosaErrorValidator.evWrongOutputType,
4188 TosaErrorValidator.evWrongInputList,
4189 TosaErrorValidator.evWrongOutputList,
4190 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004191 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004192 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004193 },
4194 "compliance": {"ulp": 0.5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004195 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004196 "log": {
4197 "op": Op.LOG,
4198 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004199 "build_fcn": (
4200 build_unary,
4201 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004202 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004203 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004204 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004205 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004206 "error_if_validators": (
4207 TosaErrorValidator.evWrongInputType,
4208 TosaErrorValidator.evWrongOutputType,
4209 TosaErrorValidator.evWrongInputList,
4210 TosaErrorValidator.evWrongOutputList,
4211 ),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004212 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004213 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson0bbd8bc2023-11-09 16:56:07 +00004214 },
4215 "compliance": {"ulp": 5},
Jared Smolens573ecd42021-03-04 15:24:10 -08004216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004217 "logical_not": {
4218 "op": Op.LOGICAL_NOT,
4219 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004220 "build_fcn": (
4221 build_unary,
4222 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004223 TosaTensorValuesGen.tvgLazyGenDefault,
4224 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004225 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004226 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004227 "error_if_validators": (
4228 TosaErrorValidator.evWrongInputType,
4229 TosaErrorValidator.evWrongOutputType,
4230 TosaErrorValidator.evWrongInputList,
4231 TosaErrorValidator.evWrongOutputList,
4232 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004233 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004234 "negate": {
4235 "op": Op.NEGATE,
4236 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004237 "build_fcn": (
4238 build_unary,
4239 TosaTensorGen.tgBasic,
4240 TosaTensorValuesGen.tvgNegate,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004241 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004242 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004243 "qgen": TosaQuantGen.qgUnary,
4244 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004245 "error_if_validators": (
4246 TosaErrorValidator.evInputZeroPointNotZero,
4247 TosaErrorValidator.evOutputZeroPointNotZero,
4248 TosaErrorValidator.evWrongInputType,
4249 TosaErrorValidator.evWrongOutputType,
4250 TosaErrorValidator.evWrongInputList,
4251 TosaErrorValidator.evWrongOutputList,
4252 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004253 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004254 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004255 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004256 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004257 "reciprocal": {
4258 "op": Op.RECIPROCAL,
4259 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004260 "build_fcn": (
4261 build_unary,
4262 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004263 TosaTensorValuesGen.tvgLazyGenDefault,
4264 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004265 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004266 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004267 "error_if_validators": (
4268 TosaErrorValidator.evWrongInputType,
4269 TosaErrorValidator.evWrongOutputType,
4270 TosaErrorValidator.evWrongInputList,
4271 TosaErrorValidator.evWrongOutputList,
4272 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004273 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004274 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004275 },
4276 "compliance": {"ulp": 1.0},
Jared Smolens573ecd42021-03-04 15:24:10 -08004277 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004278 "rsqrt": {
4279 "op": Op.RSQRT,
4280 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004281 "build_fcn": (
4282 build_unary,
4283 TosaTensorGen.tgBasic,
Jeremy Johnson30476252023-11-20 16:15:30 +00004284 TosaTensorValuesGen.tvgLogRsqrt,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004285 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004286 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004287 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004288 "error_if_validators": (
4289 TosaErrorValidator.evWrongInputType,
4290 TosaErrorValidator.evWrongOutputType,
4291 TosaErrorValidator.evWrongInputList,
4292 TosaErrorValidator.evWrongOutputList,
4293 ),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004294 "data_gen": {
evacha019c96eef2024-02-07 11:21:55 +00004295 "fp": (gtu.DataGenType.FULL_RANGE,),
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004296 },
4297 "compliance": {"ulp": 2},
Jared Smolens573ecd42021-03-04 15:24:10 -08004298 },
Jerry Ge51bd4f52024-02-20 11:21:19 -08004299 "sin": {
4300 "op": Op.SIN,
4301 "operands": (1, 0),
4302 "build_fcn": (
4303 build_unary,
4304 TosaTensorGen.tgBasic,
4305 TosaTensorValuesGen.tvgLazyGenDefault,
4306 TosaArgGen.agNone,
4307 ),
4308 "types": TYPE_FP,
4309 "error_if_validators": (
4310 TosaErrorValidator.evWrongInputType,
4311 TosaErrorValidator.evWrongOutputType,
4312 TosaErrorValidator.evWrongInputList,
4313 TosaErrorValidator.evWrongOutputList,
4314 ),
4315 "data_gen": {
4316 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4317 },
4318 "compliance": {"abs_error_normal_divisor": 2},
4319 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004320 # Elementwise Ternary operators
4321 "select": {
4322 "op": Op.SELECT,
4323 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004324 "build_fcn": (
4325 build_select,
4326 TosaTensorGen.tgBroadcastFuzz,
4327 TosaTensorValuesGen.tvgSelect,
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004328 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004329 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004330 "types": TYPE_FIB,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004331 "error_if_validators": (
4332 TosaErrorValidator.evRankMismatch,
4333 TosaErrorValidator.evWrongInputType,
4334 TosaErrorValidator.evWrongOutputType,
4335 TosaErrorValidator.evWrongInputList,
4336 TosaErrorValidator.evWrongOutputList,
4337 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004338 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004339 ),
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00004340 "data_gen": {
4341 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4342 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004343 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004344 # Comparison operators
4345 "equal": {
4346 "op": Op.EQUAL,
4347 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004348 "build_fcn": (
4349 build_comparison,
4350 TosaTensorGen.tgBroadcastFuzz,
4351 TosaTensorValuesGen.tvgEqual,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004352 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004353 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004354 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004355 "error_if_validators": (
4356 TosaErrorValidator.evRankMismatch,
4357 TosaErrorValidator.evWrongInputType,
4358 TosaErrorValidator.evWrongOutputType,
4359 TosaErrorValidator.evWrongInputList,
4360 TosaErrorValidator.evWrongOutputList,
4361 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004362 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004363 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004364 "data_gen": {
4365 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004367 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004368 "greater_equal": {
4369 "op": Op.GREATER_EQUAL,
4370 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004371 "build_fcn": (
4372 build_comparison,
4373 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004374 TosaTensorValuesGen.tvgLazyGenDefault,
4375 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004376 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004377 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004378 "error_if_validators": (
4379 TosaErrorValidator.evRankMismatch,
4380 TosaErrorValidator.evWrongInputType,
4381 TosaErrorValidator.evWrongOutputType,
4382 TosaErrorValidator.evWrongInputList,
4383 TosaErrorValidator.evWrongOutputList,
4384 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004385 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004386 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004387 "data_gen": {
4388 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4389 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004391 "greater": {
4392 "op": Op.GREATER,
4393 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004394 "build_fcn": (
4395 build_comparison,
4396 TosaTensorGen.tgBroadcastFuzz,
Jeremy Johnsona0150012023-11-15 15:52:06 +00004397 TosaTensorValuesGen.tvgLazyGenDefault,
4398 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004399 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004400 "types": TYPE_FI32,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004401 "error_if_validators": (
4402 TosaErrorValidator.evRankMismatch,
4403 TosaErrorValidator.evWrongInputType,
4404 TosaErrorValidator.evWrongOutputType,
4405 TosaErrorValidator.evWrongInputList,
4406 TosaErrorValidator.evWrongOutputList,
4407 TosaErrorValidator.evDimensionMismatch,
Jerry Ge135c9552023-05-23 20:59:32 +00004408 TosaErrorValidator.evBroadcastShapesMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004409 ),
Jeremy Johnsona0150012023-11-15 15:52:06 +00004410 "data_gen": {
4411 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4412 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004413 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004414 # Reduction operators
4415 "reduce_all": {
4416 "op": Op.REDUCE_ALL,
4417 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004418 "build_fcn": (
4419 build_reduce,
4420 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004421 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004422 TosaArgGen.agAxis,
4423 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004424 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004425 "error_if_validators": (
4426 TosaErrorValidator.evAxisLargerRank,
4427 TosaErrorValidator.evAxisSmallerZero,
4428 TosaErrorValidator.evShapeOfAxisNotOne,
4429 TosaErrorValidator.evWrongInputType,
4430 TosaErrorValidator.evWrongOutputType,
4431 TosaErrorValidator.evWrongRank,
4432 TosaErrorValidator.evWrongInputList,
4433 TosaErrorValidator.evWrongOutputList,
4434 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004435 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004436 "reduce_any": {
4437 "op": Op.REDUCE_ANY,
4438 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004439 "build_fcn": (
4440 build_reduce,
4441 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004442 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004443 TosaArgGen.agAxis,
4444 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004445 "types": TYPE_BOOL,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004446 "error_if_validators": (
4447 TosaErrorValidator.evAxisLargerRank,
4448 TosaErrorValidator.evAxisSmallerZero,
4449 TosaErrorValidator.evShapeOfAxisNotOne,
4450 TosaErrorValidator.evWrongInputType,
4451 TosaErrorValidator.evWrongOutputType,
4452 TosaErrorValidator.evWrongRank,
4453 TosaErrorValidator.evWrongInputList,
4454 TosaErrorValidator.evWrongOutputList,
4455 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004456 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004457 "reduce_max": {
4458 "op": Op.REDUCE_MAX,
4459 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004460 "build_fcn": (
4461 build_reduce,
4462 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004463 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004464 TosaArgGen.agAxis,
4465 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004466 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004467 "error_if_validators": (
4468 TosaErrorValidator.evAxisLargerRank,
4469 TosaErrorValidator.evAxisSmallerZero,
4470 TosaErrorValidator.evShapeOfAxisNotOne,
4471 TosaErrorValidator.evWrongInputType,
4472 TosaErrorValidator.evWrongOutputType,
4473 TosaErrorValidator.evWrongRank,
4474 TosaErrorValidator.evWrongInputList,
4475 TosaErrorValidator.evWrongOutputList,
4476 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004477 "data_gen": {
4478 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4479 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004480 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004481 "reduce_min": {
Jeremy Johnson8a8cca92022-03-14 12:16:46 +00004482 "op": Op.REDUCE_MIN,
Jared Smolens573ecd42021-03-04 15:24:10 -08004483 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004484 "build_fcn": (
4485 build_reduce,
4486 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004487 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004488 TosaArgGen.agAxis,
4489 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004490 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004491 "error_if_validators": (
4492 TosaErrorValidator.evAxisLargerRank,
4493 TosaErrorValidator.evAxisSmallerZero,
4494 TosaErrorValidator.evShapeOfAxisNotOne,
4495 TosaErrorValidator.evWrongInputType,
4496 TosaErrorValidator.evWrongOutputType,
4497 TosaErrorValidator.evWrongRank,
4498 TosaErrorValidator.evWrongInputList,
4499 TosaErrorValidator.evWrongOutputList,
4500 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004501 "data_gen": {
4502 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4503 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004504 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004505 "reduce_product": {
4506 "op": Op.REDUCE_PRODUCT,
4507 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004508 "build_fcn": (
4509 build_reduce,
4510 TosaTensorGen.tgBasic,
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004511 TosaTensorValuesGen.tvgReduceProduct,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004512 TosaArgGen.agAxis,
4513 ),
Jared Smolens573ecd42021-03-04 15:24:10 -08004514 "types": TYPE_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004515 "error_if_validators": (
4516 TosaErrorValidator.evAxisLargerRank,
4517 TosaErrorValidator.evAxisSmallerZero,
4518 TosaErrorValidator.evShapeOfAxisNotOne,
4519 TosaErrorValidator.evWrongInputType,
4520 TosaErrorValidator.evWrongOutputType,
4521 TosaErrorValidator.evWrongRank,
4522 TosaErrorValidator.evWrongInputList,
4523 TosaErrorValidator.evWrongOutputList,
4524 ),
Jeremy Johnsonbd801962024-01-03 17:07:44 +00004525 "data_gen": {
4526 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4527 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004528 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004529 "reduce_sum": {
4530 "op": Op.REDUCE_SUM,
4531 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004532 "build_fcn": (
4533 build_reduce,
4534 TosaTensorGen.tgBasic,
4535 TosaTensorValuesGen.tvgReduceSum,
4536 TosaArgGen.agAxis,
4537 ),
James Ward24dbc422022-10-19 12:20:31 +01004538 "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004539 "error_if_validators": (
4540 TosaErrorValidator.evAxisLargerRank,
4541 TosaErrorValidator.evAxisSmallerZero,
4542 TosaErrorValidator.evShapeOfAxisNotOne,
4543 TosaErrorValidator.evWrongInputType,
4544 TosaErrorValidator.evWrongOutputType,
4545 TosaErrorValidator.evWrongRank,
4546 TosaErrorValidator.evWrongInputList,
4547 TosaErrorValidator.evWrongOutputList,
4548 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004549 "data_gen": {
4550 "fp": (gtu.DataGenType.DOT_PRODUCT,),
4551 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004552 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004553 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08004554 "concat": {
4555 "op": Op.CONCAT,
4556 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004557 "build_fcn": (
4558 build_concat,
4559 TosaTensorGen.tgConcat,
4560 TosaTensorValuesGen.tvgConcat,
4561 TosaArgGen.agAxis,
4562 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004563 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004564 "error_if_validators": (
4565 TosaErrorValidator.evAxisLargerRank,
4566 TosaErrorValidator.evAxisSmallerZero,
4567 TosaErrorValidator.evConcatInputRankMismatch,
4568 TosaErrorValidator.evConcatShapeSumMismatch,
4569 TosaErrorValidator.evConcatInputDimMismatch,
4570 TosaErrorValidator.evWrongInputType,
4571 TosaErrorValidator.evWrongOutputType,
4572 TosaErrorValidator.evWrongOutputList,
4573 ),
Jeremy Johnson3eafe662024-01-10 13:13:35 +00004574 "data_gen": {
4575 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4576 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004577 },
4578 "pad": {
4579 "op": Op.PAD,
Tai Lye095da72024-01-25 22:00:18 +00004580 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004581 "build_fcn": (
4582 build_pad,
4583 TosaTensorGen.tgBasic,
Tai Lye095da72024-01-25 22:00:18 +00004584 TosaTensorValuesGen.tvgPad,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004585 TosaArgGen.agPad,
4586 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004587 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004588 "error_if_validators": (
4589 TosaErrorValidator.evWrongInputType,
4590 TosaErrorValidator.evPadSmallerZero,
Jeremy Johnsond32c6da2022-08-24 17:09:09 +01004591 TosaErrorValidator.evPadOutputShapeMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004592 TosaErrorValidator.evWrongOutputType,
4593 TosaErrorValidator.evWrongInputList,
4594 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004595 TosaErrorValidator.evRankMismatch,
4596 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004597 ),
Jeremy Johnsond41feb72023-10-12 16:03:15 +01004598 "data_gen": {
4599 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4600 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004601 },
Won Jeona21b2e82023-08-10 10:33:01 +00004602 "dim": {
4603 "op": Op.DIM,
4604 "operands": (1, 0),
4605 "build_fcn": (
4606 build_dim,
4607 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004608 TosaTensorValuesGen.tvgLazyGenDefault,
Won Jeona21b2e82023-08-10 10:33:01 +00004609 TosaArgGen.agAxis,
4610 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004611 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Won Jeona21b2e82023-08-10 10:33:01 +00004612 "error_if_validators": (
4613 TosaErrorValidator.evAxisLargerRank,
4614 TosaErrorValidator.evAxisSmallerZero,
4615 TosaErrorValidator.evWrongInputType,
4616 TosaErrorValidator.evWrongInputList,
4617 TosaErrorValidator.evWrongOutputList,
4618 TosaErrorValidator.evWrongRank,
4619 ),
4620 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004621 "reshape": {
4622 "op": Op.RESHAPE,
Tai Ly8690a082023-12-18 20:40:24 +00004623 "operands": (2, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004624 "build_fcn": (
4625 build_reshape,
4626 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004627 TosaTensorValuesGen.tvgReshape,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004628 TosaArgGen.agReshape,
4629 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004630 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004631 "error_if_validators": (
4632 TosaErrorValidator.evTensorSizeInputOutputMismatch,
4633 TosaErrorValidator.evWrongInputType,
4634 TosaErrorValidator.evWrongOutputType,
4635 TosaErrorValidator.evWrongInputList,
4636 TosaErrorValidator.evWrongOutputList,
4637 ),
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00004638 "data_gen": {
4639 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4640 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004641 },
4642 "reverse": {
4643 "op": Op.REVERSE,
4644 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004645 "build_fcn": (
4646 build_reverse,
4647 TosaTensorGen.tgBasic,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00004648 TosaTensorValuesGen.tvgLazyGenDefault,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004649 TosaArgGen.agAxis,
4650 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004651 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004652 "error_if_validators": (
4653 TosaErrorValidator.evAxisSmallerZero,
4654 TosaErrorValidator.evAxisLargerRank,
4655 TosaErrorValidator.evWrongInputType,
4656 TosaErrorValidator.evWrongOutputType,
4657 TosaErrorValidator.evWrongInputList,
4658 TosaErrorValidator.evWrongOutputList,
4659 ),
evacha0198477222024-01-26 12:25:32 +00004660 "data_gen": {
4661 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4662 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004663 },
4664 "slice": {
4665 "op": Op.SLICE,
TatWai Chongf15bad82024-01-31 21:33:27 -08004666 "operands": (3, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004667 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004668 "build_fcn": (
4669 build_slice,
4670 TosaTensorGen.tgBasic,
TatWai Chongf15bad82024-01-31 21:33:27 -08004671 TosaTensorValuesGen.tvgSlice,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004672 TosaArgGen.agSlice,
4673 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004674 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004675 "error_if_validators": (
TatWai Chongf15bad82024-01-31 21:33:27 -08004676 # TODO Turn off these error categories for now as the reference
4677 # model cannot allocate memory space for empty tensor. We probably
4678 # can report an accurate error message at the right place during
4679 # execution.
4680 # TosaErrorValidator.evStartSmallerZero,
4681 # TosaErrorValidator.evSizeSmallerEqualZero,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004682 TosaErrorValidator.evStartSizeOutsideBounds,
4683 TosaErrorValidator.evSizeOutputShapeMismatch,
4684 TosaErrorValidator.evInputSizeStartLengthMismatch,
4685 TosaErrorValidator.evWrongRank,
4686 TosaErrorValidator.evWrongInputType,
4687 TosaErrorValidator.evWrongOutputType,
4688 TosaErrorValidator.evWrongInputList,
4689 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004690 TosaErrorValidator.evRankMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004691 ),
evacha017f7d4252024-01-24 12:08:09 +00004692 "data_gen": {
4693 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4694 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004695 },
4696 "tile": {
4697 "op": Op.TILE,
Tai Ly8690a082023-12-18 20:40:24 +00004698 "operands": (2, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004699 "rank": (1, 6),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004700 "build_fcn": (
4701 build_tile,
4702 TosaTensorGen.tgBasic,
Won Jeon64e4bfe2024-01-18 06:31:55 +00004703 TosaTensorValuesGen.tvgTile,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004704 TosaArgGen.agTile,
4705 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004706 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004707 "error_if_validators": (
4708 TosaErrorValidator.evWrongInputType,
4709 TosaErrorValidator.evWrongOutputType,
4710 TosaErrorValidator.evWrongInputList,
4711 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004712 TosaErrorValidator.evRankMismatch,
4713 TosaErrorValidator.evWrongRank,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004714 ),
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00004715 "data_gen": {
4716 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4717 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004718 },
4719 "transpose": {
4720 "op": Op.TRANSPOSE,
4721 "operands": (1, 0),
Luke Huttona4e48ca2023-02-22 11:53:48 +00004722 "rank": (1, 6),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004723 "build_fcn": (
4724 build_transpose,
4725 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004726 TosaTensorValuesGen.tvgLazyGenDefault,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004727 TosaArgGen.agTranspose,
4728 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004729 "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004730 "error_if_validators": (
4731 TosaErrorValidator.evIndexOutsideBounds,
4732 TosaErrorValidator.evIndexUsedTwice,
4733 TosaErrorValidator.evWrongInputType,
4734 TosaErrorValidator.evWrongOutputType,
4735 TosaErrorValidator.evWrongInputList,
4736 TosaErrorValidator.evWrongOutputList,
Luke Huttona4e48ca2023-02-22 11:53:48 +00004737 TosaErrorValidator.evWrongRank,
4738 TosaErrorValidator.evRankMismatch,
4739 TosaErrorValidator.evTensorSizeInputOutputMismatch,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004740 ),
evacha0198477222024-01-26 12:25:32 +00004741 "data_gen": {
4742 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4743 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004744 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004745 # Data nodes
4746 "const": {
4747 "op": Op.CONST,
Kevin Cheng17e92022021-10-01 14:33:33 -07004748 "operands": (0, 1),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004749 "build_fcn": (
4750 build_const,
4751 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00004752 TosaTensorValuesGen.tvgLazyGenDefault,
4753 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004754 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004755 "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
evacha0198477222024-01-26 12:25:32 +00004756 "data_gen": {
4757 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4758 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004759 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004760 "identity": {
4761 "op": Op.IDENTITY,
4762 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004763 "build_fcn": (
4764 build_unary,
4765 TosaTensorGen.tgBasic,
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004766 TosaTensorValuesGen.tvgLazyGenDefault,
4767 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004768 ),
evacha011adff832024-03-06 17:33:44 +00004769 "types": TYPE_FIB + [DType.INT4, DType.INT48],
Jeremy Johnson2d70ac42023-11-06 17:46:02 +00004770 "data_gen": {
4771 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4772 },
Jared Smolens573ecd42021-03-04 15:24:10 -08004773 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004774 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08004775 "gather": {
4776 "op": Op.GATHER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004777 "operands": (2, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004778 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004779 "build_fcn": (
4780 build_gather,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004781 TosaTensorGen.tgGather,
4782 TosaTensorValuesGen.tvgGather,
4783 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004784 ),
James Ward24dbc422022-10-19 12:20:31 +01004785 "types": (
4786 DType.INT8,
4787 DType.INT16,
4788 DType.INT32,
4789 DType.FP16,
4790 DType.BF16,
4791 DType.FP32,
Won Jeon2c34b462024-02-06 18:37:00 +00004792 DType.FP8E4M3,
4793 DType.FP8E5M2,
James Ward24dbc422022-10-19 12:20:31 +01004794 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004795 "error_if_validators": (
4796 TosaErrorValidator.evWrongInputType,
4797 TosaErrorValidator.evWrongOutputType,
4798 TosaErrorValidator.evWrongInputList,
4799 TosaErrorValidator.evWrongOutputList,
4800 TosaErrorValidator.evWrongRank,
4801 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004802 "data_gen": {
4803 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4804 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004805 },
4806 "scatter": {
4807 "op": Op.SCATTER,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004808 "operands": (3, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004809 "rank": (3, 3),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004810 "build_fcn": (
4811 build_scatter,
4812 TosaTensorGen.tgScatter,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004813 TosaTensorValuesGen.tvgScatter,
4814 TosaArgGen.agNone,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004815 ),
Won Jeon2c34b462024-02-06 18:37:00 +00004816 "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004817 "error_if_validators": (
4818 TosaErrorValidator.evWrongInputType,
4819 TosaErrorValidator.evWrongOutputType,
4820 TosaErrorValidator.evWrongInputList,
4821 TosaErrorValidator.evWrongOutputList,
4822 TosaErrorValidator.evWrongRank,
4823 ),
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00004824 "data_gen": {
4825 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4826 },
Kevin Cheng550ccc52021-03-03 11:21:43 -08004827 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004828 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08004829 "resize": {
4830 "op": Op.RESIZE,
Tai Lyc5c2a7e2024-02-22 23:26:28 +00004831 "operands": (4, 0),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004832 "rank": (4, 4),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004833 "build_fcn": (
4834 build_resize,
4835 TosaTensorGen.tgNHWC,
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004836 TosaTensorValuesGen.tvgResize,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004837 TosaArgGen.agResize,
4838 ),
James Ward24dbc422022-10-19 12:20:31 +01004839 "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004840 "invalid_test_validators": (
4841 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004842 ),
4843 "error_if_validators": (
4844 TosaErrorValidator.evMaxDimExceeded,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004845 TosaErrorValidator.evScaleSmallerEqualZero,
4846 TosaErrorValidator.evScaleNLargerMax,
4847 TosaErrorValidator.evScaleDLargerMax,
4848 TosaErrorValidator.evOffsetSmallerMin,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004849 TosaErrorValidator.evOffsetLargerEqualMax,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004850 TosaErrorValidator.evBorderSmallerMin,
4851 TosaErrorValidator.evBorderLargerEqualMax,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004852 TosaErrorValidator.evWrongInputType,
4853 TosaErrorValidator.evWrongOutputType,
4854 TosaErrorValidator.evWrongRank,
4855 TosaErrorValidator.evWrongInputList,
4856 TosaErrorValidator.evWrongOutputList,
4857 TosaErrorValidator.evBatchMismatch,
4858 TosaErrorValidator.evChannelMismatch,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01004859 TosaErrorValidator.evResizeOutputShapeMismatch,
4860 TosaErrorValidator.evResizeOutputShapeNonInteger,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004861 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00004862 "data_gen": {
4863 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4864 },
4865 "compliance": {"relative": 0.006},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004866 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004867 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08004868 "cast": {
4869 "op": Op.CAST,
4870 "operands": (1, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004871 "build_fcn": (
4872 build_cast,
4873 TosaTensorGen.tgBasic,
Jeremy Johnson708da822023-11-15 16:25:45 +00004874 TosaTensorValuesGen.tvgCast,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004875 TosaArgGen.agCast,
4876 ),
James Ward8b390432022-08-12 20:48:56 +01004877 "types": (
4878 DType.FP16,
James Ward24dbc422022-10-19 12:20:31 +01004879 DType.BF16,
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01004880 DType.FP32,
James Ward8b390432022-08-12 20:48:56 +01004881 DType.INT8,
4882 DType.INT16,
4883 DType.INT32,
4884 DType.BOOL,
Won Jeon2c34b462024-02-06 18:37:00 +00004885 DType.FP8E4M3,
4886 DType.FP8E5M2,
James Ward8b390432022-08-12 20:48:56 +01004887 ),
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004888 "error_if_validators": (
4889 TosaErrorValidator.evWrongInputType,
4890 TosaErrorValidator.evWrongOutputType,
4891 TosaErrorValidator.evWrongInputList,
4892 TosaErrorValidator.evWrongOutputList,
4893 ),
Jeremy Johnson708da822023-11-15 16:25:45 +00004894 "data_gen": {
4895 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
4896 },
4897 "compliance": {"ulp": 0.5},
Kevin Cheng550ccc52021-03-03 11:21:43 -08004898 },
4899 "rescale": {
4900 "op": Op.RESCALE,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004901 "operands": (3, 0),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004902 "build_fcn": (
4903 build_rescale,
4904 TosaTensorGen.tgBasic,
Tai Ly6e1e2bc2024-03-01 20:59:32 +00004905 TosaTensorValuesGen.tvgRescale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004906 TosaArgGen.agRescale,
4907 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004908 "types": [
4909 DType.UINT8,
4910 DType.INT8,
4911 DType.INT16,
4912 DType.INT32,
4913 DType.INT48,
4914 DType.UINT16,
4915 ],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004916 "error_if_validators": (
4917 TosaErrorValidator.evInputZeroPointNotZero,
4918 TosaErrorValidator.evOutputZeroPointNotZero,
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01004919 TosaErrorValidator.evU16InputZeroPointNotValid,
4920 TosaErrorValidator.evU16OutputZeroPointNotValid,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004921 TosaErrorValidator.evScaleTrue,
4922 TosaErrorValidator.evScaleNotTrue,
4923 TosaErrorValidator.evWrongInputType,
4924 TosaErrorValidator.evWrongOutputType,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004925 TosaErrorValidator.evWrongInputList,
4926 TosaErrorValidator.evWrongOutputList,
4927 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004928 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004929 # Custom
4930 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08004931 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07004932 # Two variants of cond_if, one that generates one of two constant tensors (no
4933 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
4934 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08004935 "cond_if_const": {
4936 "op": Op.COND_IF,
4937 "operands": (0, 2),
4938 "build_fcn": (
4939 build_cond_if_const,
4940 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004941 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004942 TosaArgGen.agCondIf,
4943 ),
4944 "types": [DType.BOOL],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004945 "error_if_validators": (
4946 TosaErrorValidator.evOutputListThenGraphMismatch,
4947 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004948 TosaErrorValidator.evCondIfCondNotMatchingBool,
4949 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004950 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004951 },
4952 "cond_if_binary": {
4953 "op": Op.COND_IF,
4954 "operands": (2, 0),
4955 "build_fcn": (
4956 build_cond_if_binary,
4957 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004958 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004959 TosaArgGen.agCondIf,
4960 ),
Les Bell6040b4d2021-10-11 12:50:31 +01004961 "types": TYPE_INT_FP,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004962 "error_if_validators": (
4963 TosaErrorValidator.evInputListThenGraphMismatch,
4964 TosaErrorValidator.evInputListElseGraphMismatch,
4965 TosaErrorValidator.evOutputListThenGraphMismatch,
4966 TosaErrorValidator.evOutputListElseGraphMismatch,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004967 TosaErrorValidator.evCondIfCondNotMatchingBool,
4968 TosaErrorValidator.evCondIfCondShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004969 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004970 },
Eric Kunzee5e26762020-10-13 16:11:07 -07004971 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08004972 "while_loop": {
4973 "op": Op.WHILE_LOOP,
4974 "operands": (0, 1),
4975 "build_fcn": (
4976 build_while_loop,
4977 TosaTensorGen.tgBasic,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01004978 TosaTensorValuesGen.tvgCondIfWhileLoop,
Kevin Cheng550ccc52021-03-03 11:21:43 -08004979 TosaArgGen.agWhileLoop,
4980 ),
4981 "types": [DType.INT32],
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004982 "error_if_validators": (
4983 TosaErrorValidator.evInputListOutputListMismatch,
4984 TosaErrorValidator.evInputListCondGraphMismatch,
4985 TosaErrorValidator.evInputListBodyGraphInputMismatch,
4986 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
4987 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
Jeremy Johnson05c711e2022-12-12 18:00:41 +00004988 TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00004989 ),
Kevin Cheng550ccc52021-03-03 11:21:43 -08004990 },
Luke Hutton57287132023-02-06 14:54:18 +00004991 "fft2d": {
4992 "op": Op.FFT2D,
4993 "operands": (2, 0),
4994 "rank": (3, 3),
4995 "build_fcn": (
4996 build_fft2d,
4997 TosaTensorGen.tgFFT2d,
Jeremy Johnsonc8330812024-01-18 16:57:28 +00004998 TosaTensorValuesGen.tvgLazyGenDefault,
Luke Hutton57287132023-02-06 14:54:18 +00004999 TosaArgGen.agFFT2d,
5000 ),
5001 "types": [DType.FP32],
5002 "error_if_validators": (
5003 TosaErrorValidator.evWrongInputType,
5004 TosaErrorValidator.evWrongOutputType,
5005 TosaErrorValidator.evWrongInputList,
5006 TosaErrorValidator.evWrongOutputList,
5007 TosaErrorValidator.evWrongRank,
5008 TosaErrorValidator.evBatchMismatch,
5009 TosaErrorValidator.evKernelNotPowerOfTwo,
5010 TosaErrorValidator.evFFTInputShapeMismatch,
5011 TosaErrorValidator.evFFTOutputShapeMismatch,
5012 ),
Jeremy Johnsonc8330812024-01-18 16:57:28 +00005013 "data_gen": {
5014 "fp": (gtu.DataGenType.DOT_PRODUCT,),
5015 },
Luke Hutton57287132023-02-06 14:54:18 +00005016 },
Luke Hutton261b7b62023-01-10 14:50:31 +00005017 "rfft2d": {
5018 "op": Op.RFFT2D,
5019 "operands": (1, 0),
5020 "rank": (3, 3),
5021 "build_fcn": (
5022 build_rfft2d,
5023 TosaTensorGen.tgRFFT2d,
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005024 TosaTensorValuesGen.tvgLazyGenDefault,
5025 TosaArgGen.agRFFT2d,
Luke Hutton261b7b62023-01-10 14:50:31 +00005026 ),
5027 "types": [DType.FP32],
5028 "error_if_validators": (
5029 TosaErrorValidator.evWrongInputType,
5030 TosaErrorValidator.evWrongOutputType,
5031 TosaErrorValidator.evWrongInputList,
5032 TosaErrorValidator.evWrongOutputList,
5033 TosaErrorValidator.evWrongRank,
5034 TosaErrorValidator.evBatchMismatch,
5035 TosaErrorValidator.evKernelNotPowerOfTwo,
Luke Hutton57287132023-02-06 14:54:18 +00005036 TosaErrorValidator.evFFTOutputShapeMismatch,
Luke Hutton261b7b62023-01-10 14:50:31 +00005037 ),
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00005038 "data_gen": {
5039 "fp": (gtu.DataGenType.DOT_PRODUCT,),
5040 },
Luke Hutton261b7b62023-01-10 14:50:31 +00005041 },
Won Jeon74342e52024-01-09 00:34:40 +00005042 # Shape
5043 "add_shape": {
5044 "op": Op.ADD_SHAPE,
5045 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005046 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005047 "build_fcn": (
5048 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005049 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005050 TosaTensorValuesGen.tvgAddSub,
5051 TosaArgGen.agNone,
5052 ),
5053 "types": [DType.SHAPE],
5054 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5055 },
5056 "sub_shape": {
5057 "op": Op.SUB_SHAPE,
5058 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005059 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005060 "build_fcn": (
5061 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005062 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005063 TosaTensorValuesGen.tvgAddSub,
5064 TosaArgGen.agNone,
5065 ),
5066 "types": [DType.SHAPE],
5067 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5068 },
5069 "mul_shape": {
5070 "op": Op.MUL_SHAPE,
5071 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005072 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005073 "build_fcn": (
5074 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005075 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005076 TosaTensorValuesGen.tvgMul,
5077 TosaArgGen.agNone,
5078 ),
5079 "types": [DType.SHAPE],
5080 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5081 },
5082 "div_shape": {
5083 "op": Op.DIV_SHAPE,
5084 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005085 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005086 "build_fcn": (
5087 build_shape_op,
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005088 TosaTensorGen.tgBasic,
Won Jeon74342e52024-01-09 00:34:40 +00005089 TosaTensorValuesGen.tvgIntDiv,
5090 TosaArgGen.agNone,
5091 ),
5092 "types": [DType.SHAPE],
5093 "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
5094 },
5095 "concat_shape": {
5096 "op": Op.CONCAT_SHAPE,
5097 "operands": (2, 0),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005098 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005099 "build_fcn": (
5100 build_concat,
5101 TosaTensorGen.tgConcat,
5102 TosaTensorValuesGen.tvgConcat,
5103 TosaArgGen.agNone,
5104 ),
5105 "types": [DType.SHAPE],
5106 "error_if_validators": (),
5107 },
5108 "const_shape": {
5109 "op": Op.CONST_SHAPE,
5110 "operands": (0, 1),
Jeremy Johnsonfc4bde92024-01-25 12:53:21 +00005111 "rank": (1, 1),
Won Jeon74342e52024-01-09 00:34:40 +00005112 "build_fcn": (
5113 build_const,
5114 TosaTensorGen.tgBasic,
evacha0198477222024-01-26 12:25:32 +00005115 TosaTensorValuesGen.tvgLazyGenDefault,
5116 TosaArgGen.agNone,
Won Jeon74342e52024-01-09 00:34:40 +00005117 ),
5118 "types": [DType.SHAPE],
5119 },
Eric Kunzee5e26762020-10-13 16:11:07 -07005120 }
5121
Kevin Cheng550ccc52021-03-03 11:21:43 -08005122
Eric Kunzee5e26762020-10-13 16:11:07 -07005123class OutputShaper:
5124 # Methods in this class compute the expected output shape and datatype
5125 # for common classes of operations
def __init__(self):
    """No-op constructor: no instance state is set up here."""
5128
5129 # These methods return arguments that can be used for
5130 # creating a new output tensor
@staticmethod
def binaryBroadcastOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a two-input op whose inputs may broadcast.

    Each output dimension is copied from ``a``; when no error is being
    injected, a size-1 dimension of ``a`` is replaced by the matching
    dimension of ``b``.

    Args:
        ser: serializer object; only ``addOutput(shape, dtype)`` is called.
        rng: random source; ``integers()`` and ``choice()`` are called.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    # Ranks only need to agree when a rank mismatch is not being tested
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[idx] if dim == 1 and error_name is None else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally: rng is consumed here whether or not the
    # index ends up being used
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
        # Any candidate except the input dtype is a wrong output dtype
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005164
5165 @staticmethod
5166 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005167 assert len(a.shape) == len(b.shape)
5168 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07005169
5170 shape = []
5171 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08005172 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07005173 shape.append(a.shape[i])
5174
Kevin Cheng550ccc52021-03-03 11:21:43 -08005175 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005176
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
    """Add the output tensor for a single-input op.

    The output has the shape of ``a``. Its dtype is that of ``a`` unless
    a wrong output type is being injected, in which case a different
    dtype is picked with ``rng.choice``.

    Args:
        ser: serializer object; only ``addOutput(shape, dtype)`` is called.
        rng: random source; ``choice()`` is called for the error case.
        a: input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput`` for the chosen shape and dtype.
    """
    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any candidate except the input dtype is a wrong output dtype
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(a.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005195
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
    """Add the output tensor for a select-style op (condition plus two inputs).

    Each output dimension is copied from ``cond``; when no error is being
    injected, a size-1 dimension of ``cond`` is replaced by the largest
    of the matching dimensions of ``cond``, ``a`` and ``b``.

    Args:
        ser: serializer object; only ``addOutput(shape, dtype)`` is called.
        rng: random source; ``integers()`` and ``choice()`` are called.
        cond: condition tensor (``shape`` is read).
        a: first value tensor (``shape`` and ``dtype`` are read).
        b: second value tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    # Ranks only need to agree when a rank mismatch is not being tested
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(dim, a.shape[idx], b.shape[idx])
        if dim == 1 and error_name is None
        else dim
        for idx, dim in enumerate(cond.shape)
    ]

    # Drawn unconditionally: rng is consumed here whether or not the
    # index ends up being used
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any candidate except the input dtype is a wrong output dtype
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005229
@staticmethod
def binaryComparisonOp(ser, rng, a, b, error_name=None):
    """Add the output tensor for a two-input comparison op.

    Each output dimension is copied from ``a``, except that a size-1
    dimension of ``a`` is replaced by the matching dimension of ``b``
    when ``b`` has one. The output dtype is ``DType.BOOL`` unless a wrong
    output type is being injected.

    Args:
        ser: serializer object; only ``addOutput(shape, dtype)`` is called.
        rng: random source; ``integers()`` and ``choice()`` are called.
        a: first input tensor (``shape`` and ``dtype`` are read).
        b: second input tensor (``shape`` and ``dtype`` are read).
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    # Ranks only need to agree when a rank mismatch is not being tested
    if error_name != ErrorIf.RankMismatch:
        assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    # Broadcast: the length check guards against b having a lower rank
    out_shape = [
        b.shape[idx] if dim == 1 and len(b.shape) > idx else dim
        for idx, dim in enumerate(a.shape)
    ]

    # Drawn unconditionally: rng is consumed here whether or not the
    # index ends up being used
    fuzz_idx = rng.integers(0, len(a.shape))
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[fuzz_idx] += 1

    out_dtype = DType.BOOL
    if error_name == ErrorIf.WrongOutputType:
        # None of these is BOOL, so any pick is a wrong output dtype
        non_bool_dtypes = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        out_dtype = rng.choice(non_bool_dtypes)

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005263
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for a reduction along ``axis``.

    The output shape is a copy of the input shape with ``axis`` set to 1,
    except for the axis-related error cases where the shape is left as
    is. For ShapeOfAxisNotOne an axis that is already 1 is replaced by a
    value from ``rng.integers(2, 10)``.

    Args:
        ser: serializer object; only ``addOutput(shape, dtype)`` is called.
        rng: random source; ``integers()`` and ``choice()`` are called.
        a: input tensor (``shape`` and ``dtype`` are read).
        axis: index of the dimension being reduced.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    out_shape = a.shape.copy()

    keep_axis_errors = (
        ErrorIf.AxisSmallerZero,
        ErrorIf.AxisLargerRank,
        ErrorIf.ShapeOfAxisNotOne,
    )
    if error_name not in keep_axis_errors:
        out_shape[axis] = 1
    # Force the axis away from 1 when it happens to be 1 in the input
    if error_name == ErrorIf.ShapeOfAxisNotOne and out_shape[axis] == 1:
        out_shape[axis] = rng.integers(2, 10)

    out_dtype = a.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        ]
        # Any candidate except the input dtype is a wrong output dtype
        out_dtype = rng.choice(list(set(candidates) - {a.dtype}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005292
@staticmethod
def argmaxOp(ser, rng, a, axis, error_name=None):
    """Add the output tensor for an argmax along ``axis``.

    The output shape is a copy of the input shape with ``axis`` removed
    (unless an axis error is being injected). The output dtype is
    ``DType.INT32`` unless a wrong output type is being injected.

    Args:
        ser: serializer object; only ``addOutput(shape, dtype)`` is called.
        rng: random source; ``integers()`` and ``choice()`` are called.
        a: input tensor (``shape`` is read).
        axis: index of the dimension being reduced.
        error_name: optional ErrorIf case to inject into the output.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and dtype.
    """
    out_shape = a.shape.copy()

    if error_name not in (ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank):
        del out_shape[axis]

    if error_name == ErrorIf.ArgmaxOutputRankMismatch:
        # Randomly drop a leading dimension or add a trailing one;
        # dropping is skipped when only one dimension is left
        if rng.choice([True, False]) and len(out_shape) > 1:
            del out_shape[0]
        else:
            out_shape.append(1)
    elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
        # Grow every dimension by its own random amount
        out_shape = [dim + rng.integers(1, 10) for dim in out_shape]

    out_dtype = DType.INT32
    if error_name == ErrorIf.WrongOutputType:
        candidates = [
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ]
        # Any candidate except INT32 is a wrong output dtype
        out_dtype = rng.choice(list(set(candidates) - {DType.INT32}))

    return ser.addOutput(out_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005328
@staticmethod
def _get_conv_output_type(input_dtype):
    """Return the output dtype used for a convolution with this input dtype.

    Mapping implemented here:
        FP16 / BF16 / FP32 -> same as the input
        FP8E4M3 / FP8E5M2  -> FP16
        INT8 / INT4        -> INT32
        INT16              -> INT48

    Args:
        input_dtype: DType value of the convolution input.

    Returns:
        The DType value for the convolution output.

    Raises:
        AssertionError: if ``input_dtype`` is not one of the types above.
    """
    if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
        return input_dtype
    if input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        return DType.FP16
    if input_dtype in (DType.INT8, DType.INT4):
        return DType.INT32
    if input_dtype in (DType.INT16,):
        return DType.INT48
    # The previous `assert True, ...` could never fail, so unsupported types
    # silently produced None. Raise explicitly (also survives `python -O`).
    raise AssertionError(f"Unsupported convolution data type {input_dtype}")
5340
@staticmethod
def conv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor for a CONV2D test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accepted for signature compatibility; not used here.
        strides: per-axis strides, indexed [0] (height) and [1] (width).
        padding: indexed [0..1] for height and [2..3] for width.
        dilations: per-axis dilations, indexed [0] and [1].
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # IFM: NHWC
    # Filter: OHWI
    # OFM: NHWC

    def _out_size(in_size, pad_lo, pad_hi, kernel, dilation, stride):
        # Output extent of one spatial axis
        return (in_size - 1 + pad_lo + pad_hi - (kernel - 1) * dilation) // stride + 1

    out_h = _out_size(
        ifm.shape[1], padding[0], padding[1], filter.shape[1], dilations[0], strides[0]
    )
    out_w = _out_size(
        ifm.shape[2], padding[2], padding[3], filter.shape[2], dilations[1], strides[1]
    )

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in [1, 3]:
            out_h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * strides[1]

    ofm_shape = [ifm.shape[0], out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(ofm_shape, DType.INT32)

    out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        elif ifm.dtype in (DType.FP8E4M3, DType.FP8E5M2):
            excludes = [DType.FP16]
        else:
            excludes = [out_dtype]
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005394
@staticmethod
def conv3dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor for a CONV3D test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accepted for signature compatibility; not used here.
        strides: per-axis strides, indexed [0..2] (depth, height, width).
        padding: pairs per axis, indexed [0..5].
        dilations: per-axis dilations, indexed [0..2].
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # IFM: NDHWC
    # Filter: ODHWI
    # OFM: NDHWC

    # Output extents in D, H, W order; axis k uses shape index k + 1
    spatial = [
        (
            ifm.shape[k + 1]
            - 1
            + padding[2 * k]
            + padding[2 * k + 1]
            - (filter.shape[k + 1] - 1) * dilations[k]
        )
        // strides[k]
        + 1
        for k in range(3)
    ]

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3, 4]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        # (1 -> depth, 2 -> height, 3 -> width, 4 -> all three)
        for k in range(3):
            if which in [k + 1, 4]:
                spatial[k] = spatial[k] + (rng.choice(steps) * strides[k])

    out_d, out_h, out_w = spatial
    ofm_shape = [ifm.shape[0], out_d, out_h, out_w, filter.shape[0]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(ofm_shape, DType.INT32)

    out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        excludes = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        out_dtype = rng.choice(list(gtu.usableDTypes(excludes=excludes)))

    return ser.addOutput(ofm_shape, out_dtype)
5456
@staticmethod
def depthwiseConv2dOp(
    ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
    """Register the output tensor for a DEPTHWISE_CONV2D test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        filter: weight tensor (``shape`` is read).
        accum_dtype: accepted for signature compatibility; not used here.
        strides: per-axis strides, indexed [0] (height) and [1] (width).
        padding: indexed [0..1] for height and [2..3] for width.
        dilations: per-axis dilations, indexed [0] and [1].
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # IFM: NHWC
    # Filter: HWCM
    # OFM: NHW C*M

    padded_h = ifm.shape[1] - 1 + padding[0] + padding[1]
    out_h = (padded_h - (filter.shape[0] - 1) * dilations[0]) // strides[0] + 1

    padded_w = ifm.shape[2] - 1 + padding[2] + padding[3]
    out_w = (padded_w - (filter.shape[1] - 1) * dilations[1]) // strides[1] + 1

    if error_name == ErrorIf.ConvOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in [1, 3]:
            out_h += rng.choice(steps) * strides[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * strides[1]

    out_channels = filter.shape[2] * filter.shape[3]
    ofm_shape = [ifm.shape[0], out_h, out_w, out_channels]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(ofm_shape, DType.INT32)

    out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)

    if error_name == ErrorIf.WrongOutputType:
        if ifm.dtype == DType.FP16:
            excludes = [DType.FP16, DType.FP32]
        else:
            excludes = [out_dtype]
        candidates = list(gtu.usableDTypes(excludes=excludes))
        out_dtype = rng.choice(candidates)

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005507
@staticmethod
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
    """Register the output tensor for a 2D pooling test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        ifm: input tensor (``shape`` and ``dtype`` are read).
        kernel: kernel sizes, indexed [0] (height) and [1] (width).
        stride: strides, indexed [0] and [1].
        pad: padding, indexed [0..1] for height and [2..3] for width.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # input: NHWC
    params_usable = stride[0] > 0 and stride[1] > 0 and min(pad) >= 0
    if params_usable:
        out_h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
        out_w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
    else:
        # If an incorrect stride is used set dimensions to 1, test is invalid anyway.
        out_h = 1
        out_w = 1

    if error_name == ErrorIf.PoolingOutputShapeMismatch:
        steps = [1, 2, 3]
        which = rng.choice(steps)
        # increment in multiples of stride to not hit non-integer error case
        if which in [1, 3]:
            out_h += rng.choice(steps) * stride[0]
        if which in [2, 3]:
            out_w += rng.choice(steps) * stride[1]
    ofm_shape = [ifm.shape[0], out_h, out_w, ifm.shape[3]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(ofm_shape, ifm.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    wrong_dtypes = list(set(candidates) - {ifm.dtype})
    return ser.addOutput(ofm_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005547
@staticmethod
def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
    """Register the output tensor for a FULLY_CONNECTED test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: unused here.
        input: input tensor; ``shape[0]`` gives the first output dimension.
        filter: weight tensor; ``shape[0]`` gives the second output dimension.
        accum_dtype: used directly as the output dtype.
        error_name: unused here.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    # input: N, IC
    # filter: OC, IC
    # output: N, OC
    batch = input.shape[0]
    out_channels = filter.shape[0]

    # Output dtype is the accumulator dtype:
    # validated in arg_gen (also invalidated for ErrorIf)
    return ser.addOutput([batch, out_channels], accum_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005560
@staticmethod
def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
    """Register the output tensor for a MATMUL test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        a: first operand (``shape`` and ``dtype`` are read).
        b: second operand (``shape`` is read).
        accum_dtype: output dtype for valid tests.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.

    Note:
        For WrongOutputType, candidate lists are defined only for the input
        dtypes handled below; any other input dtype leaves the list unbound.
    """
    # a: N, H, C
    # b: N, C, W
    # out: N, H, W
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        return ser.addOutput(output_shape, DType.INT32)

    if error_name != ErrorIf.WrongOutputType:
        # Validated in arg_gen
        return ser.addOutput(output_shape, accum_dtype)

    in_dtype = a.dtype
    if in_dtype == DType.INT8:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    elif in_dtype == DType.INT16:
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FP32,
            DType.FP16,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    elif in_dtype in (DType.FP8E4M3, DType.FP8E5M2):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.BF16,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )
    elif in_dtype in (DType.FP32, DType.FP16, DType.BF16):
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP8E4M3,
            DType.FP8E5M2,
        )

    return ser.addOutput(output_shape, rng.choice(a=incorrect_types))
Eric Kunzee5e26762020-10-13 16:11:07 -07005626
@staticmethod
def concatOp(ser, rng, axis, inputs, error_name=None):
    """Register the output tensor for a CONCAT test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``integers`` and ``choice``.
        axis: dimension along which the inputs are joined.
        inputs: sequence of tensors; the first supplies the base shape/dtype.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    first = inputs[0]
    others = inputs[1:]

    # calculate the output shape, if possible, otherwise just use the first input shape
    output_shape = first.shape.copy()

    # unable to concat tensors of different ranks
    rank_mismatch = error_name == ErrorIf.ConcatInputRankMismatch
    # unable to concat tensors along an invalid axis
    axis_invalid = error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
    if not rank_mismatch and not axis_invalid:
        for tensor in others:
            output_shape[axis] += tensor.shape[axis]

    if error_name == ErrorIf.ConcatShapeSumMismatch:
        output_shape[axis] += rng.integers(5, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, first.dtype)

    candidates = {
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    }
    wrong_dtypes = list(candidates - {first.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005662
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
    """Register the output tensor for a PAD test.

    Each output dimension is the input dimension plus the two padding
    amounts given for it in ``padding``.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        a: input tensor (``shape`` and ``dtype`` are read).
        padding: per-dimension pairs, read as ``padding[i][0]`` / ``[i][1]``.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    output_shape = a.shape.copy()

    for dim in range(len(output_shape)):
        pads = padding[dim]
        output_shape[dim] = pads[0] + pads[1] + output_shape[dim]

    if error_name == ErrorIf.PadOutputShapeMismatch:
        # Enlarge a single, randomly selected dimension
        bad_dim = rng.choice(range(len(output_shape)))
        output_shape[bad_dim] += rng.choice([1, 2])
    elif error_name == ErrorIf.RankMismatch:
        output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(output_shape, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(output_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005693
@staticmethod
def dimOp(ser, rng, a, axis, error_name=None):
    """Register the output tensor for a DIM test.

    The output shape is always ``[1]`` and the dtype is ``DType.SHAPE``
    unless a WrongOutputType ERROR_IF is requested.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        a: unused here.
        axis: unused here.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns.
    """
    result_shape = [1]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(result_shape, DType.SHAPE)

    # None of these is DType.SHAPE, so every candidate is a wrong type
    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidates))
    return ser.addOutput(result_shape, rng.choice(wrong_dtypes))
5714
@staticmethod
def reshapeOp(ser, rng, a, shape, error_name=None):
    """Register the output tensor for a RESHAPE test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``integers`` and ``choice``.
        a: input tensor (``dtype`` is read).
        shape: requested new shape; a copy is used for the output.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    new_shape = shape.copy()

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        # Grow every dimension by a random amount
        for dim in range(len(new_shape)):
            new_shape[dim] += rng.integers(1, 10)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(new_shape, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(new_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005739
@staticmethod
def sliceOp(ser, rng, input, start, size, error_name=None):
    """Register the output tensor for a SLICE test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        input: input tensor (``shape`` and ``dtype`` are read).
        start: unused here.
        size: slice size; a copy is the output shape for valid tests.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    if error_name == ErrorIf.WrongOutputType:
        candidates = (
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
            DType.FP32,
            DType.FP16,
            DType.BF16,
        )
        wrong_dtypes = list(set(candidates) - {input.dtype})
        result_dtype = rng.choice(wrong_dtypes)
    else:
        result_dtype = input.dtype

    result_shape = size.copy()
    if error_name == ErrorIf.SizeOutputShapeMismatch:
        for dim in range(len(result_shape)):
            # Dimensions of 2 or less are only ever increased
            if result_shape[dim] > 2:
                delta = rng.choice([-2, -1, 1, 2])
            else:
                delta = rng.choice([1, 2])
            result_shape[dim] = result_shape[dim] + delta
    elif error_name == ErrorIf.InputSizeStartLengthMismatch:
        result_shape = input.shape.copy()
    elif error_name == ErrorIf.RankMismatch:
        result_shape = gtu.get_rank_mismatch_shape(rng, result_shape)

    return ser.addOutput(result_shape, result_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005773
@staticmethod
def tileOp(ser, rng, a, multiples, error_name=None):
    """Register the output tensor for a TILE test.

    Each output dimension is the input dimension times its multiple.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        a: input tensor (``shape`` and ``dtype`` are read).
        multiples: per-dimension factors; must match the rank of ``a``.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    tiled_shape = a.shape.copy()
    assert len(multiples) == len(tiled_shape)

    for dim in range(len(tiled_shape)):
        tiled_shape[dim] = a.shape[dim] * multiples[dim]

    if error_name == ErrorIf.RankMismatch:
        tiled_shape = gtu.get_rank_mismatch_shape(rng, tiled_shape)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(tiled_shape, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(tiled_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005802
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
    """Register the output tensor for a TRANSPOSE test.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``integers`` and ``choice``.
        a: input tensor (``shape`` and ``dtype`` are read).
        perms: permutation; output dimension i takes ``a.shape[perms[i]]``.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    permuted = a.shape.copy()

    assert len(perms) == len(permuted)

    # With bad permutation indices the input shape is kept as-is
    perms_are_bad = error_name in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]
    if not perms_are_bad:
        for dim in range(len(permuted)):
            permuted[dim] = a.shape[perms[dim]]

    if error_name == ErrorIf.TensorSizeInputOutputMismatch:
        for dim in range(len(permuted)):
            permuted[dim] += rng.integers(1, 10)
    elif error_name == ErrorIf.RankMismatch:
        permuted = gtu.get_rank_mismatch_shape(rng, permuted)

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(permuted, a.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidates) - {a.dtype})
    return ser.addOutput(permuted, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005835
@staticmethod
def gatherOp(ser, rng, values, indices, error_name=None):
    """Register the output tensor for a GATHER test.

    The output shape is
    ``[values.shape[0], indices.shape[1], values.shape[2]]``.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        values: source tensor (``shape`` and ``dtype`` are read).
        indices: index tensor (``shape`` is read); must be rank 2.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # The rank of `values` is only checked when not testing WrongRank
    if error_name != ErrorIf.WrongRank:
        assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    gathered_shape = [values.shape[0], indices.shape[1], values.shape[2]]

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(gathered_shape, values.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    )
    wrong_dtypes = list(set(candidates) - {values.dtype})
    return ser.addOutput(gathered_shape, rng.choice(wrong_dtypes))
Kevin Cheng77d0f762020-11-24 10:26:32 -08005861
@staticmethod
def scatterOp(ser, rng, values_in, indices, input, error_name=None):
    """Register the output tensor for a SCATTER test.

    The output uses the shape object of ``values_in`` directly (no copy).

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        values_in: tensor being updated (``shape`` and ``dtype`` are read).
        indices: index tensor (``shape`` is read); must be rank 2.
        input: tensor of new values (``shape`` is read); must be rank 3.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # The rank of `values_in` is only checked when not testing WrongRank
    if error_name != ErrorIf.WrongRank:
        assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    result_shape = values_in.shape

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(result_shape, values_in.dtype)

    candidates = (
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
        DType.FP8E4M3,
        DType.FP8E5M2,
    )
    wrong_dtypes = list(set(candidates) - {values_in.dtype})
    return ser.addOutput(result_shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005892
@staticmethod
def tableOp(ser, rng, input, error_name=None):
    """Register the output tensor for a TABLE test.

    The output has the input's shape. An INT16 input gives an INT32 output;
    any other input dtype gives INT8.

    Args:
        ser: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice``.
        input: input tensor (``shape`` and ``dtype`` are read).
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``ser.addOutput`` returns for the computed shape and dtype.
    """
    # Same shape as the input, dtype dependent on input dtype
    if error_name != ErrorIf.WrongInputType:
        assert input.dtype in (DType.INT16, DType.INT8)

    if input.dtype == DType.INT16:
        expected_dtype = DType.INT32
    else:
        expected_dtype = DType.INT8

    if error_name != ErrorIf.WrongOutputType:
        return ser.addOutput(input.shape, expected_dtype)

    candidates = [
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.INT48,
        DType.FP32,
        DType.FP16,
        DType.BF16,
    ]
    # Keep list order; only the correct output dtype is dropped
    wrong_dtypes = [dtype for dtype in candidates if dtype != expected_dtype]
    return ser.addOutput(input.shape, rng.choice(wrong_dtypes))
Eric Kunzee5e26762020-10-13 16:11:07 -07005912
@staticmethod
def resizeOp(
    serializer,
    rng,
    input,
    mode,
    scale,
    offset,
    border,
    input_dtype,
    output_dtype,
    error_name=None,
):
    """Register the output tensor for a RESIZE test.

    Args:
        serializer: serializer; only ``addOutput`` is called on it.
        rng: random generator providing ``choice`` and ``integers``.
        input: input tensor (``shape`` is read).
        mode: unused here.
        scale: indexed [0..3] as y numerator, y denominator, x numerator,
            x denominator.
        offset: indexed [0] (y) and [1] (x).
        border: indexed [0] (y) and [1] (x).
        input_dtype: unused here.
        output_dtype: dtype passed straight through to the output.
        error_name: ERROR_IF being generated, or None for a valid test.

    Returns:
        Whatever ``serializer.addOutput`` returns.
    """
    # Calculate OH, OW
    scale_y_n, scale_y_d = scale[0], scale[1]
    scale_x_n, scale_x_d = scale[2], scale[3]
    if error_name == ErrorIf.ScaleSmallerEqualZero:
        # Use positive stand-ins so the size arithmetic below stays usable
        scale_y_n, scale_y_d = max(scale_y_n, 1), max(scale_y_d, 1)
        scale_x_n, scale_x_d = max(scale_x_n, 1), max(scale_x_d, 1)

    oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
    ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1

    if error_name is not None:
        # Make sure the output tensor is valid, which can occur when
        # scale, offset or border have been changed for ERROR_IFs
        oh = max(oh, 1)
        ow = max(ow, 1)
        if error_name != ErrorIf.MaxDimExceeded:
            largest = gtu.MAX_RESIZE_DIMENSION - 1
            oh = min(oh, largest)
            ow = min(ow, largest)

    if error_name == ErrorIf.ResizeOutputShapeMismatch:
        options = [1, 2, 3]
        pick = rng.choice(options)
        # increment in multiples of scale_y/x_d so we don't hit non-integer error case
        if pick in [1, 3]:
            if oh + scale_y_d < gtu.MAX_RESIZE_DIMENSION:
                oh += scale_y_d
            else:
                oh -= scale_y_d
                assert oh > 0  # Should have been caught in agResize
        if pick in [2, 3]:
            if ow + scale_x_d < gtu.MAX_RESIZE_DIMENSION:
                ow += scale_x_d
            else:
                ow -= scale_x_d
                assert ow > 0  # Should have been caught in agResize

    batch = input.shape[0]
    if error_name == ErrorIf.WrongRank:
        # shape[3] is not read in this case; shape[0] is reused instead
        channels = input.shape[0]
    else:
        channels = input.shape[3]

    if error_name == ErrorIf.BatchMismatch:
        batch = batch + rng.integers(1, 10)
    elif error_name == ErrorIf.ChannelMismatch:
        channels = channels + rng.integers(1, 10)

    return serializer.addOutput([batch, oh, ow, channels], output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005991
5992 @staticmethod
def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
    """Register the output tensor for a type-conversion operator.

    The output takes its shape from ``val.shape`` and its dtype from
    ``out_dtype``.  ``rng`` and ``error_name`` are accepted but are not
    used by this function.

    Returns whatever ``ser.addOutput`` returns.
    """
    result_shape = val.shape
    return ser.addOutput(result_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07005995
5996 @staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
    """Register the output tensor for a transpose-conv2d operator.

    ``output_shape`` is used as the output shape.  For the
    ``ConvOutputShapeMismatch`` error case, entries 1 and/or 2 of
    ``output_shape`` are increased by a random amount drawn from
    ``[1, 2, 3]``; note that this modifies the caller's ``output_shape``
    object in place.

    The output dtype is derived from ``ifm.dtype`` via
    ``OutputShaper._get_conv_output_type`` (or fixed to ``DType.INT32`` for
    ``WrongInputType``), and is then swapped for a randomly chosen
    different dtype for ``WrongOutputType``.

    ``accum_dtype`` is accepted but is not used by this function.

    Returns whatever ``ser.addOutput`` returns.
    """
    if error_name == ErrorIf.ConvOutputShapeMismatch:
        amounts = [1, 2, 3]
        # The first draw selects which entries to disturb:
        # 1 -> index 1 only, 2 -> index 2 only, 3 -> both.
        selector = rng.choice(amounts)
        if selector in (1, 3):
            output_shape[1] += rng.choice(amounts)
        if selector in (2, 3):
            output_shape[2] += rng.choice(amounts)

    # Pick some potentially correct output dtype if input type is incorrect
    out_dtype = (
        DType.INT32
        if error_name == ErrorIf.WrongInputType
        else OutputShaper._get_conv_output_type(ifm.dtype)
    )

    if error_name == ErrorIf.WrongOutputType:
        # For FP16 inputs both FP16 and FP32 are excluded from the
        # candidates; otherwise only the dtype selected above is excluded.
        excluded = (
            [DType.FP16, DType.FP32] if ifm.dtype == DType.FP16 else [out_dtype]
        )
        candidates = list(gtu.usableDTypes(excludes=excluded))
        out_dtype = rng.choice(candidates)

    return ser.addOutput(output_shape, out_dtype)
Luke Hutton261b7b62023-01-10 14:50:31 +00006021
6022 @staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
    """Register the two output tensors for an FFT2D operator.

    Both outputs share one shape and one dtype.  By default these are a
    copy of ``ifm1.shape`` and ``ifm1.dtype``; selected error cases then
    perturb them:

    * ``WrongOutputType``: dtype is replaced by a random usable dtype
      other than ``DType.FP32``.
    * ``BatchMismatch``: entry 0 of the shape grows by ``rng.integers(1, 10)``.
    * ``FFTOutputShapeMismatch``: entry 1 or 2 (chosen at random) grows by
      ``rng.integers(1, 10)``.

    Asserts that the two inputs agree in dtype, agree in shape (unless
    testing ``FFTInputShapeMismatch``) and have rank 3 (unless testing
    ``WrongRank``).

    Returns a list holding the two ``serializer.addOutput`` results.
    """
    assert ifm1.dtype == ifm2.dtype

    if error_name != ErrorIf.FFTInputShapeMismatch:
        assert ifm1.shape == ifm2.shape

    in_shape = ifm1.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Work on a copy so the input tensor's shape is left untouched.
    out_shape = in_shape.copy()
    out_dtype = ifm1.dtype

    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    # Two identically shaped/typed outputs are registered, in order.
    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
6052
6053 @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
    """Register the two output tensors for an RFFT2D operator.

    The output shape equals ``value.shape`` except for the last entry,
    which becomes ``value.shape[-1] // 2 + 1``.  The output dtype defaults
    to ``value.dtype``.  Selected error cases then perturb these:

    * ``WrongOutputType``: dtype is replaced by a random usable dtype
      other than ``DType.FP32``.
    * ``BatchMismatch``: entry 0 of the shape grows by ``rng.integers(1, 10)``.
    * ``FFTOutputShapeMismatch``: entry 1 or 2 (chosen at random) grows by
      ``rng.integers(1, 10)``.

    Asserts that the input has rank 3 unless testing ``WrongRank``.

    Returns a list holding the two ``serializer.addOutput`` results, both
    created with the same shape and dtype.
    """
    in_shape = value.shape
    if error_name != ErrorIf.WrongRank:
        assert len(in_shape) == 3

    # Leading entries are kept; only the last one is reduced.
    reduced_last = in_shape[-1] // 2 + 1
    out_shape = list(in_shape[:-1]) + [reduced_last]

    out_dtype = value.dtype
    if error_name == ErrorIf.WrongOutputType:
        candidates = list(gtu.usableDTypes(excludes=[DType.FP32]))
        out_dtype = rng.choice(candidates)
    elif error_name == ErrorIf.BatchMismatch:
        out_shape[0] += rng.integers(1, 10)
    elif error_name == ErrorIf.FFTOutputShapeMismatch:
        axis = rng.choice([1, 2])
        out_shape[axis] += rng.integers(1, 10)

    return [serializer.addOutput(out_shape, out_dtype) for _ in range(2)]
Won Jeon74342e52024-01-09 00:34:40 +00006077
6078 @staticmethod
def addShapeOp(ser, rng, a, b, error_name=None):
    """Register the output tensor for an element-wise shape operator.

    The output shape is a copy of ``a.shape``; for ``DimensionMismatch``
    one randomly selected entry is increased by 1.  The output dtype is
    ``DType.SHAPE``, except for ``WrongOutputType`` where it is chosen at
    random from ``gtu.get_wrong_output_type(a.dtype)``.

    Asserts that ``a`` and ``b`` have equal dtype, and equal rank unless
    testing ``RankMismatch``.

    Returns whatever ``ser.addOutput`` returns.
    """
    rank = len(a.shape)
    if error_name != ErrorIf.RankMismatch:
        assert rank == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = list(a.shape)

    # The index is drawn on every call, whether or not it ends up being
    # used, so the number of values taken from rng does not depend on
    # error_name here.
    target = rng.integers(0, rank)
    if error_name == ErrorIf.DimensionMismatch:
        out_shape[target] += 1

    if error_name == ErrorIf.WrongOutputType:
        out_dtype = rng.choice(gtu.get_wrong_output_type(a.dtype))
    else:
        out_dtype = DType.SHAPE
    return ser.addOutput(out_shape, out_dtype)